Anthropic conversation extended thinking support (#139662)

* Anthropic conversation extended thinking support
* update conversation snapshots
* Add conversation test
* Update openai_conversation snapshots
* Removed metadata
* Removed metadata
* Removed thinking
* cosmetic fix
* combine user messages
* Apply suggestions from code review
* Add tests for chat_log messages conversion
* s/THINKING_BUDGET_TOKENS/THINKING_BUDGET/
* Apply suggestions from code review
* Update tests
* Update homeassistant/components/anthropic/strings.json

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* apply suggestions from code review

---------

Co-authored-by: Robert Resch <robert@resch.dev>
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
parent baafcf48dc
commit 07e7672b78
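For context before the diff: extended thinking is switched on through the Messages API's `thinking` parameter, which is exactly what this commit wires into the integration's options. A minimal standalone sketch of that API usage (model name, prompt, and token numbers are illustrative, not taken from this commit):

```python
import asyncio

import anthropic


async def main() -> None:
    # AsyncAnthropic() reads ANTHROPIC_API_KEY from the environment.
    client = anthropic.AsyncAnthropic()
    stream = await client.messages.create(
        model="claude-3-7-sonnet-latest",
        max_tokens=8192,
        # The budget must be at least 1024 tokens and must stay below
        # max_tokens, the same rule this commit enforces via
        # MIN_THINKING_BUDGET and the options-flow validation.
        thinking={"type": "enabled", "budget_tokens": 2048},
        messages=[{"role": "user", "content": "Why is the sky blue?"}],
        stream=True,
    )
    async for event in stream:
        # Thinking deltas stream before the visible text deltas.
        print(event.type)


asyncio.run(main())
```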
homeassistant/components/anthropic/config_flow.py

@@ -34,10 +34,12 @@ from .const import (
     CONF_PROMPT,
     CONF_RECOMMENDED,
     CONF_TEMPERATURE,
+    CONF_THINKING_BUDGET,
     DOMAIN,
     RECOMMENDED_CHAT_MODEL,
     RECOMMENDED_MAX_TOKENS,
     RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_THINKING_BUDGET,
 )
 
 _LOGGER = logging.getLogger(__name__)
@@ -128,21 +130,29 @@ class AnthropicOptionsFlow(OptionsFlow):
     ) -> ConfigFlowResult:
         """Manage the options."""
         options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
+        errors: dict[str, str] = {}
 
         if user_input is not None:
             if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
                 if user_input[CONF_LLM_HASS_API] == "none":
                     user_input.pop(CONF_LLM_HASS_API)
-                return self.async_create_entry(title="", data=user_input)
 
-            # Re-render the options again, now with the recommended options shown/hidden
-            self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
+                if user_input.get(
+                    CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET
+                ) >= user_input.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS):
+                    errors[CONF_THINKING_BUDGET] = "thinking_budget_too_large"
 
-            options = {
-                CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
-                CONF_PROMPT: user_input[CONF_PROMPT],
-                CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
-            }
+                if not errors:
+                    return self.async_create_entry(title="", data=user_input)
+            else:
+                # Re-render the options again, now with the recommended options shown/hidden
+                self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
+
+                options = {
+                    CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
+                    CONF_PROMPT: user_input[CONF_PROMPT],
+                    CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
+                }
 
         suggested_values = options.copy()
         if not suggested_values.get(CONF_PROMPT):
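The saved options are now validated: a thinking budget that reaches or exceeds `max_tokens` is rejected with the `thinking_budget_too_large` form error, since the budget is spent out of the same overall response token allowance. A condensed sketch of that rule (plain dict keys stand in for the option constants):

```python
# Condensed restatement of the check added above; key strings mirror
# CONF_THINKING_BUDGET / CONF_MAX_TOKENS and the defaults from const.py.
RECOMMENDED_THINKING_BUDGET = 0
RECOMMENDED_MAX_TOKENS = 1024


def validate(user_input: dict) -> dict[str, str]:
    errors: dict[str, str] = {}
    if user_input.get(
        "thinking_budget", RECOMMENDED_THINKING_BUDGET
    ) >= user_input.get("max_tokens", RECOMMENDED_MAX_TOKENS):
        errors["thinking_budget"] = "thinking_budget_too_large"
    return errors


assert validate({"max_tokens": 8192, "thinking_budget": 16384}) == {
    "thinking_budget": "thinking_budget_too_large"
}
assert validate({"max_tokens": 8192, "thinking_budget": 2048}) == {}
```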
@@ -156,6 +166,7 @@ class AnthropicOptionsFlow(OptionsFlow):
         return self.async_show_form(
             step_id="init",
             data_schema=schema,
+            errors=errors or None,
         )
 
 
@@ -205,6 +216,10 @@ def anthropic_config_option_schema(
                 CONF_TEMPERATURE,
                 default=RECOMMENDED_TEMPERATURE,
             ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
+            vol.Optional(
+                CONF_THINKING_BUDGET,
+                default=RECOMMENDED_THINKING_BUDGET,
+            ): int,
         }
     )
     return schema
@ -13,3 +13,8 @@ CONF_MAX_TOKENS = "max_tokens"
|
|||||||
RECOMMENDED_MAX_TOKENS = 1024
|
RECOMMENDED_MAX_TOKENS = 1024
|
||||||
CONF_TEMPERATURE = "temperature"
|
CONF_TEMPERATURE = "temperature"
|
||||||
RECOMMENDED_TEMPERATURE = 1.0
|
RECOMMENDED_TEMPERATURE = 1.0
|
||||||
|
CONF_THINKING_BUDGET = "thinking_budget"
|
||||||
|
RECOMMENDED_THINKING_BUDGET = 0
|
||||||
|
MIN_THINKING_BUDGET = 1024
|
||||||
|
|
||||||
|
THINKING_MODELS = ["claude-3-7-sonnet-20250219", "claude-3-7-sonnet-latest"]
|
||||||
|
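`RECOMMENDED_THINKING_BUDGET` defaults to 0, which keeps thinking off, because the API only accepts budgets of at least `MIN_THINKING_BUDGET` (1024) tokens. The conversation code below combines these constants into a single gate; restated on its own:

```python
# Illustrative predicate mirroring how conversation.py decides to enable thinking.
THINKING_MODELS = ["claude-3-7-sonnet-20250219", "claude-3-7-sonnet-latest"]
MIN_THINKING_BUDGET = 1024


def thinking_enabled(model: str, thinking_budget: int) -> bool:
    """Thinking is requested only for supported models with a budget at or above the API minimum."""
    return model in THINKING_MODELS and thinking_budget >= MIN_THINKING_BUDGET


assert not thinking_enabled("claude-3-7-sonnet-latest", 0)  # default budget keeps it off
assert thinking_enabled("claude-3-7-sonnet-latest", 2048)
assert not thinking_enabled("claude-3-5-haiku-latest", 2048)  # model not in the list
```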
homeassistant/components/anthropic/conversation.py

@@ -1,23 +1,32 @@
 """Conversation support for Anthropic."""
 
-from collections.abc import AsyncGenerator, Callable
+from collections.abc import AsyncGenerator, Callable, Iterable
 import json
-from typing import Any, Literal
+from typing import Any, Literal, cast
 
 import anthropic
 from anthropic import AsyncStream
 from anthropic._types import NOT_GIVEN
 from anthropic.types import (
     InputJSONDelta,
-    Message,
     MessageParam,
     MessageStreamEvent,
     RawContentBlockDeltaEvent,
     RawContentBlockStartEvent,
     RawContentBlockStopEvent,
+    RawMessageStartEvent,
+    RawMessageStopEvent,
+    RedactedThinkingBlock,
+    RedactedThinkingBlockParam,
+    SignatureDelta,
     TextBlock,
     TextBlockParam,
     TextDelta,
+    ThinkingBlock,
+    ThinkingBlockParam,
+    ThinkingConfigDisabledParam,
+    ThinkingConfigEnabledParam,
+    ThinkingDelta,
     ToolParam,
     ToolResultBlockParam,
     ToolUseBlock,
@@ -39,11 +48,15 @@ from .const import (
     CONF_MAX_TOKENS,
     CONF_PROMPT,
     CONF_TEMPERATURE,
+    CONF_THINKING_BUDGET,
     DOMAIN,
     LOGGER,
+    MIN_THINKING_BUDGET,
     RECOMMENDED_CHAT_MODEL,
     RECOMMENDED_MAX_TOKENS,
     RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_THINKING_BUDGET,
+    THINKING_MODELS,
 )
 
 # Max number of back and forth with the LLM to generate a response
@@ -71,73 +84,101 @@ def _format_tool(
     )
 
 
-def _message_convert(
-    message: Message,
-) -> MessageParam:
-    """Convert from class to TypedDict."""
-    param_content: list[TextBlockParam | ToolUseBlockParam] = []
-
-    for message_content in message.content:
-        if isinstance(message_content, TextBlock):
-            param_content.append(TextBlockParam(type="text", text=message_content.text))
-        elif isinstance(message_content, ToolUseBlock):
-            param_content.append(
-                ToolUseBlockParam(
-                    type="tool_use",
-                    id=message_content.id,
-                    name=message_content.name,
-                    input=message_content.input,
-                )
-            )
-
-    return MessageParam(role=message.role, content=param_content)
-
-
-def _convert_content(chat_content: conversation.Content) -> MessageParam:
-    """Create tool response content."""
-    if isinstance(chat_content, conversation.ToolResultContent):
-        return MessageParam(
-            role="user",
-            content=[
-                ToolResultBlockParam(
-                    type="tool_result",
-                    tool_use_id=chat_content.tool_call_id,
-                    content=json.dumps(chat_content.tool_result),
-                )
-            ],
-        )
-    if isinstance(chat_content, conversation.AssistantContent):
-        return MessageParam(
-            role="assistant",
-            content=[
-                TextBlockParam(type="text", text=chat_content.content or ""),
-                *[
-                    ToolUseBlockParam(
-                        type="tool_use",
-                        id=tool_call.id,
-                        name=tool_call.tool_name,
-                        input=tool_call.tool_args,
-                    )
-                    for tool_call in chat_content.tool_calls or ()
-                ],
-            ],
-        )
-    if isinstance(chat_content, conversation.UserContent):
-        return MessageParam(
-            role="user",
-            content=chat_content.content,
-        )
-    # Note: We don't pass SystemContent here as its passed to the API as the prompt
-    raise ValueError(f"Unexpected content type: {type(chat_content)}")
+def _convert_content(
+    chat_content: Iterable[conversation.Content],
+) -> list[MessageParam]:
+    """Transform HA chat_log content into Anthropic API format."""
+    messages: list[MessageParam] = []
+
+    for content in chat_content:
+        if isinstance(content, conversation.ToolResultContent):
+            tool_result_block = ToolResultBlockParam(
+                type="tool_result",
+                tool_use_id=content.tool_call_id,
+                content=json.dumps(content.tool_result),
+            )
+            if not messages or messages[-1]["role"] != "user":
+                messages.append(
+                    MessageParam(
+                        role="user",
+                        content=[tool_result_block],
+                    )
+                )
+            elif isinstance(messages[-1]["content"], str):
+                messages[-1]["content"] = [
+                    TextBlockParam(type="text", text=messages[-1]["content"]),
+                    tool_result_block,
+                ]
+            else:
+                messages[-1]["content"].append(tool_result_block)  # type: ignore[attr-defined]
+        elif isinstance(content, conversation.UserContent):
+            # Combine consequent user messages
+            if not messages or messages[-1]["role"] != "user":
+                messages.append(
+                    MessageParam(
+                        role="user",
+                        content=content.content,
+                    )
+                )
+            elif isinstance(messages[-1]["content"], str):
+                messages[-1]["content"] = [
+                    TextBlockParam(type="text", text=messages[-1]["content"]),
+                    TextBlockParam(type="text", text=content.content),
+                ]
+            else:
+                messages[-1]["content"].append(  # type: ignore[attr-defined]
+                    TextBlockParam(type="text", text=content.content)
+                )
+        elif isinstance(content, conversation.AssistantContent):
+            # Combine consequent assistant messages
+            if not messages or messages[-1]["role"] != "assistant":
+                messages.append(
+                    MessageParam(
+                        role="assistant",
+                        content=[],
+                    )
+                )
+
+            if content.content:
+                messages[-1]["content"].append(  # type: ignore[union-attr]
+                    TextBlockParam(type="text", text=content.content)
+                )
+            if content.tool_calls:
+                messages[-1]["content"].extend(  # type: ignore[union-attr]
+                    [
+                        ToolUseBlockParam(
+                            type="tool_use",
+                            id=tool_call.id,
+                            name=tool_call.tool_name,
+                            input=tool_call.tool_args,
+                        )
+                        for tool_call in content.tool_calls
+                    ]
+                )
+        else:
+            # Note: We don't pass SystemContent here as its passed to the API as the prompt
+            raise TypeError(f"Unexpected content type: {type(content)}")
+
+    return messages
 
 
 async def _transform_stream(
     result: AsyncStream[MessageStreamEvent],
+    messages: list[MessageParam],
 ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
     """Transform the response stream into HA format.
 
     A typical stream of responses might look something like the following:
     - RawMessageStartEvent with no content
+    - RawContentBlockStartEvent with an empty ThinkingBlock (if extended thinking is enabled)
+    - RawContentBlockDeltaEvent with a ThinkingDelta
+    - RawContentBlockDeltaEvent with a ThinkingDelta
+    - RawContentBlockDeltaEvent with a ThinkingDelta
+    - ...
+    - RawContentBlockDeltaEvent with a SignatureDelta
+    - RawContentBlockStopEvent
+    - RawContentBlockStartEvent with a RedactedThinkingBlock (occasionally)
+    - RawContentBlockStopEvent (RedactedThinkingBlock does not have a delta)
     - RawContentBlockStartEvent with an empty TextBlock
     - RawContentBlockDeltaEvent with a TextDelta
     - RawContentBlockDeltaEvent with a TextDelta
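The rewritten `_convert_content` consumes the whole chat log at once and merges consecutive same-role entries into a single API message, since the Messages API expects alternating user and assistant turns. A sketch of the observable behavior, matching the `test_history_conversion` snapshots added below (a hypothetical usage sketch; the content classes come from `homeassistant.components.conversation.chat_log`):

```python
from homeassistant.components.conversation.chat_log import (
    AssistantContent,
    UserContent,
)

content = [
    UserContent("What shape is a donut?"),
    UserContent("Can you tell me?"),
    AssistantContent(agent_id="conversation.claude", content="A donut is a torus."),
]
# Two consecutive user entries collapse into one user message with two text blocks.
assert _convert_content(content) == [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What shape is a donut?"},
            {"type": "text", "text": "Can you tell me?"},
        ],
    },
    {
        "role": "assistant",
        "content": [{"type": "text", "text": "A donut is a torus."}],
    },
]
```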
@@ -151,44 +192,103 @@ async def _transform_stream(
     - RawContentBlockStopEvent
     - RawMessageDeltaEvent with a stop_reason='tool_use'
     - RawMessageStopEvent(type='message_stop')
+
+    Each message could contain multiple blocks of the same type.
     """
     if result is None:
         raise TypeError("Expected a stream of messages")
 
-    current_tool_call: dict | None = None
+    current_message: MessageParam | None = None
+    current_block: (
+        TextBlockParam
+        | ToolUseBlockParam
+        | ThinkingBlockParam
+        | RedactedThinkingBlockParam
+        | None
+    ) = None
+    current_tool_args: str
 
     async for response in result:
         LOGGER.debug("Received response: %s", response)
 
-        if isinstance(response, RawContentBlockStartEvent):
+        if isinstance(response, RawMessageStartEvent):
+            if response.message.role != "assistant":
+                raise ValueError("Unexpected message role")
+            current_message = MessageParam(role=response.message.role, content=[])
+        elif isinstance(response, RawContentBlockStartEvent):
             if isinstance(response.content_block, ToolUseBlock):
-                current_tool_call = {
-                    "id": response.content_block.id,
-                    "name": response.content_block.name,
-                    "input": "",
-                }
+                current_block = ToolUseBlockParam(
+                    type="tool_use",
+                    id=response.content_block.id,
+                    name=response.content_block.name,
+                    input="",
+                )
+                current_tool_args = ""
             elif isinstance(response.content_block, TextBlock):
+                current_block = TextBlockParam(
+                    type="text", text=response.content_block.text
+                )
                 yield {"role": "assistant"}
+                if response.content_block.text:
+                    yield {"content": response.content_block.text}
+            elif isinstance(response.content_block, ThinkingBlock):
+                current_block = ThinkingBlockParam(
+                    type="thinking",
+                    thinking=response.content_block.thinking,
+                    signature=response.content_block.signature,
+                )
+            elif isinstance(response.content_block, RedactedThinkingBlock):
+                current_block = RedactedThinkingBlockParam(
+                    type="redacted_thinking", data=response.content_block.data
+                )
+                LOGGER.debug(
+                    "Some of Claude’s internal reasoning has been automatically "
+                    "encrypted for safety reasons. This doesn’t affect the quality of "
+                    "responses"
+                )
         elif isinstance(response, RawContentBlockDeltaEvent):
+            if current_block is None:
+                raise ValueError("Unexpected delta without a block")
             if isinstance(response.delta, InputJSONDelta):
-                if current_tool_call is None:
-                    raise ValueError("Unexpected delta without a tool call")
-                current_tool_call["input"] += response.delta.partial_json
+                current_tool_args += response.delta.partial_json
             elif isinstance(response.delta, TextDelta):
-                LOGGER.debug("yielding delta: %s", response.delta.text)
+                text_block = cast(TextBlockParam, current_block)
+                text_block["text"] += response.delta.text
                 yield {"content": response.delta.text}
+            elif isinstance(response.delta, ThinkingDelta):
+                thinking_block = cast(ThinkingBlockParam, current_block)
+                thinking_block["thinking"] += response.delta.thinking
+            elif isinstance(response.delta, SignatureDelta):
+                thinking_block = cast(ThinkingBlockParam, current_block)
+                thinking_block["signature"] += response.delta.signature
        elif isinstance(response, RawContentBlockStopEvent):
-            if current_tool_call:
+            if current_block is None:
+                raise ValueError("Unexpected stop event without a current block")
+            if current_block["type"] == "tool_use":
+                tool_block = cast(ToolUseBlockParam, current_block)
+                tool_args = json.loads(current_tool_args)
+                tool_block["input"] = tool_args
                 yield {
                     "tool_calls": [
                         llm.ToolInput(
-                            id=current_tool_call["id"],
-                            tool_name=current_tool_call["name"],
-                            tool_args=json.loads(current_tool_call["input"]),
+                            id=tool_block["id"],
+                            tool_name=tool_block["name"],
+                            tool_args=tool_args,
                         )
                     ]
                 }
-            current_tool_call = None
+            elif current_block["type"] == "thinking":
+                thinking_block = cast(ThinkingBlockParam, current_block)
+                LOGGER.debug("Thinking: %s", thinking_block["thinking"])
+
+            if current_message is None:
+                raise ValueError("Unexpected stop event without a current message")
+            current_message["content"].append(current_block)  # type: ignore[union-attr]
+            current_block = None
+        elif isinstance(response, RawMessageStopEvent):
+            if current_message is not None:
+                messages.append(current_message)
+                current_message = None
 
 
 class AnthropicConversationEntity(
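`_transform_stream` now does double duty: it yields HA-format deltas for the chat log while also rebuilding the raw assistant message, thinking and `redacted_thinking` blocks included, into the caller-supplied `messages` list. That side channel matters because thinking signatures have to be sent back to the API unmodified on the next tool-call round trip. A hypothetical consumption sketch (assuming `stream` came from `client.messages.create(..., stream=True)`; `handle` is a placeholder for the chat-log plumbing):

```python
async def drive(stream) -> list[MessageParam]:
    messages: list[MessageParam] = []
    async for delta in _transform_stream(stream, messages):
        # Each delta carries only what the HA chat_log needs:
        # a role marker, text content, or finished tool calls.
        handle(delta)  # placeholder for chat_log.async_add_delta_content_stream
    # After the RawMessageStopEvent, the complete assistant message,
    # thinking blocks included, has been appended to `messages`.
    return messages
```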
@@ -254,34 +354,50 @@ class AnthropicConversationEntity(
         system = chat_log.content[0]
         if not isinstance(system, conversation.SystemContent):
             raise TypeError("First message must be a system message")
-        messages = [_convert_content(content) for content in chat_log.content[1:]]
+        messages = _convert_content(chat_log.content[1:])
 
         client = self.entry.runtime_data
 
+        thinking_budget = options.get(CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET)
+        model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
+
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
-            try:
-                stream = await client.messages.create(
-                    model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
-                    messages=messages,
-                    tools=tools or NOT_GIVEN,
-                    max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
-                    system=system.content,
-                    temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
-                    stream=True,
-                )
+            model_args = {
+                "model": model,
+                "messages": messages,
+                "tools": tools or NOT_GIVEN,
+                "max_tokens": options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
+                "system": system.content,
+                "stream": True,
+            }
+            if model in THINKING_MODELS and thinking_budget >= MIN_THINKING_BUDGET:
+                model_args["thinking"] = ThinkingConfigEnabledParam(
+                    type="enabled", budget_tokens=thinking_budget
+                )
+            else:
+                model_args["thinking"] = ThinkingConfigDisabledParam(type="disabled")
+                model_args["temperature"] = options.get(
+                    CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
+                )
+
+            try:
+                stream = await client.messages.create(**model_args)
             except anthropic.AnthropicError as err:
                 raise HomeAssistantError(
                     f"Sorry, I had a problem talking to Anthropic: {err}"
                 ) from err
 
             messages.extend(
-                [
-                    _convert_content(content)
-                    async for content in chat_log.async_add_delta_content_stream(
-                        user_input.agent_id, _transform_stream(stream)
-                    )
-                ]
+                _convert_content(
+                    [
+                        content
+                        async for content in chat_log.async_add_delta_content_stream(
+                            user_input.agent_id, _transform_stream(stream, messages)
+                        )
+                        if not isinstance(content, conversation.AssistantContent)
+                    ]
+                )
             )
 
             if not chat_log.unresponded_tool_results:
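Note the asymmetry in the branch above: `temperature` is only applied when thinking is disabled, which is consistent with the API rejecting sampling overrides while extended thinking is on. The two request shapes `model_args` can take (an illustrative sketch; model names and values are not from this commit):

```python
with_thinking = {
    "model": "claude-3-7-sonnet-latest",
    "messages": [...],  # converted chat log
    "max_tokens": 8192,
    "system": "...",
    "stream": True,
    "thinking": {"type": "enabled", "budget_tokens": 2048},
    # no "temperature" key on this branch
}
without_thinking = {
    "model": "claude-3-5-haiku-latest",
    "messages": [...],
    "max_tokens": 1024,
    "system": "...",
    "stream": True,
    "thinking": {"type": "disabled"},
    "temperature": 1.0,
}
```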
homeassistant/components/anthropic/strings.json

@@ -23,12 +23,17 @@
         "max_tokens": "Maximum tokens to return in response",
         "temperature": "Temperature",
         "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
-        "recommended": "Recommended model settings"
+        "recommended": "Recommended model settings",
+        "thinking_budget_tokens": "Thinking budget"
       },
       "data_description": {
-        "prompt": "Instruct how the LLM should respond. This can be a template."
+        "prompt": "Instruct how the LLM should respond. This can be a template.",
+        "thinking_budget_tokens": "The number of tokens the model can use to think about the response out of the total maximum number of tokens. Set to 1024 or greater to enable extended thinking."
       }
     }
-    }
+    },
+    "error": {
+      "thinking_budget_too_large": "Maximum tokens must be greater than the thinking budget."
+    }
   }
 }
tests/components/anthropic/conftest.py

@@ -5,6 +5,7 @@ from unittest.mock import patch
 
 import pytest
 
+from homeassistant.components.anthropic import CONF_CHAT_MODEL
 from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant
 from homeassistant.helpers import llm

@@ -38,6 +39,21 @@ def mock_config_entry_with_assist(
     return mock_config_entry
 
 
+@pytest.fixture
+def mock_config_entry_with_extended_thinking(
+    hass: HomeAssistant, mock_config_entry: MockConfigEntry
+) -> MockConfigEntry:
+    """Mock a config entry with assist."""
+    hass.config_entries.async_update_entry(
+        mock_config_entry,
+        options={
+            CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
+            CONF_CHAT_MODEL: "claude-3-7-sonnet-latest",
+        },
+    )
+    return mock_config_entry
+
+
 @pytest.fixture
 async def mock_init_component(
     hass: HomeAssistant, mock_config_entry: MockConfigEntry
tests/components/anthropic/snapshots/test_conversation.ambr

@@ -1,4 +1,321 @@
 # serializer version: 1
+# name: test_extended_thinking_tool_call
+  list([
+    dict({
+      'content': '''
+        Current time is 16:00:00. Today's date is 2024-06-03.
+        You are a voice assistant for Home Assistant.
+        Answer questions about the world truthfully.
+        Answer in plain text. Keep it simple and to the point.
+        Only if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.
+      ''',
+      'role': 'system',
+    }),
+    dict({
+      'content': 'Please call the test function',
+      'role': 'user',
+    }),
+    dict({
+      'agent_id': 'conversation.claude',
+      'content': 'Certainly, calling it now!',
+      'role': 'assistant',
+      'tool_calls': list([
+        dict({
+          'id': 'toolu_0123456789AbCdEfGhIjKlM',
+          'tool_args': dict({
+            'param1': 'test_value',
+          }),
+          'tool_name': 'test_tool',
+        }),
+      ]),
+    }),
+    dict({
+      'agent_id': 'conversation.claude',
+      'role': 'tool_result',
+      'tool_call_id': 'toolu_0123456789AbCdEfGhIjKlM',
+      'tool_name': 'test_tool',
+      'tool_result': 'Test response',
+    }),
+    dict({
+      'agent_id': 'conversation.claude',
+      'content': 'I have successfully called the function',
+      'role': 'assistant',
+      'tool_calls': None,
+    }),
+  ])
+# ---
+# name: test_extended_thinking_tool_call.1
+  list([
+    dict({
+      'content': 'Please call the test function',
+      'role': 'user',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'signature': 'ErUBCkYIARgCIkCYXaVNJShe3A86Hp7XUzh9YsCYBbJTbQsrklTAPtJ2sP/NoB6tSzpK/nTL6CjSo2R6n0KNBIg5MH6asM2R/kmaEgyB/X1FtZq5OQAC7jUaDEPWCdcwGQ4RaBy5wiIwmRxExIlDhoY6tILoVPnOExkC/0igZxHEwxK8RU/fmw0b+o+TwAarzUitwzbo21E5Kh3pa3I6yqVROf1t2F8rFocNUeCegsWV/ytwYV+ayA==',
+          'thinking': 'The user asked me to call a test function.Is it a test? What would the function do? Would it violate any privacy or security policies?',
+          'type': 'thinking',
+        }),
+        dict({
+          'data': 'EroBCkYIARgCKkBJDytPJhw//4vy3t7aE+LfIkxvkAh51cBPrAvBCo6AjgI57Zt9KWPnUVV50OQJ0KZzUFoGZG5sxg95zx4qMwkoEgz43Su3myJKckvj03waDBZLIBSeoAeRUeVsJCIwQ5edQN0sa+HNeB/KUBkoMUwV+IT0eIhcpFxnILdvxUAKM4R1o4KG3x+yO0eo/kyOKiKfrCPFQhvBVmTZPFhgA2Ow8L9gGDVipcz6x3Uu9YETGEny',
+          'type': 'redacted_thinking',
+        }),
+        dict({
+          'signature': 'ErUBCkYIARgCIkCYXaVNJShe3A86Hp7XUzh9YsCYBbJTbQsrklTAPtJ2sP/NoB6tSzpK/nTL6CjSo2R6n0KNBIg5MH6asM2R/kmaEgyB/X1FtZq5OQAC7jUaDEPWCdcwGQ4RaBy5wiIwmRxExIlDhoY6tILoVPnOExkC/0igZxHEwxK8RU/fmw0b+o+TwAarzUitwzbo21E5Kh3pa3I6yqVROf1t2F8rFocNUeCegsWV/ytwYV+ayA==',
+          'thinking': "Okay, let's give it a shot. Will I pass the test?",
+          'type': 'thinking',
+        }),
+        dict({
+          'text': 'Certainly, calling it now!',
+          'type': 'text',
+        }),
+        dict({
+          'id': 'toolu_0123456789AbCdEfGhIjKlM',
+          'input': dict({
+            'param1': 'test_value',
+          }),
+          'name': 'test_tool',
+          'type': 'tool_use',
+        }),
+      ]),
+      'role': 'assistant',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'content': '"Test response"',
+          'tool_use_id': 'toolu_0123456789AbCdEfGhIjKlM',
+          'type': 'tool_result',
+        }),
+      ]),
+      'role': 'user',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'text': 'I have successfully called the function',
+          'type': 'text',
+        }),
+      ]),
+      'role': 'assistant',
+    }),
+  ])
+# ---
+# name: test_history_conversion[content0]
+  list([
+    dict({
+      'content': 'Are you sure?',
+      'role': 'user',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'text': 'Yes, I am sure!',
+          'type': 'text',
+        }),
+      ]),
+      'role': 'assistant',
+    }),
+  ])
+# ---
+# name: test_history_conversion[content1]
+  list([
+    dict({
+      'content': 'What shape is a donut?',
+      'role': 'user',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'text': 'A donut is a torus.',
+          'type': 'text',
+        }),
+      ]),
+      'role': 'assistant',
+    }),
+    dict({
+      'content': 'Are you sure?',
+      'role': 'user',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'text': 'Yes, I am sure!',
+          'type': 'text',
+        }),
+      ]),
+      'role': 'assistant',
+    }),
+  ])
+# ---
+# name: test_history_conversion[content2]
+  list([
+    dict({
+      'content': list([
+        dict({
+          'text': 'What shape is a donut?',
+          'type': 'text',
+        }),
+        dict({
+          'text': 'Can you tell me?',
+          'type': 'text',
+        }),
+      ]),
+      'role': 'user',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'text': 'A donut is a torus.',
+          'type': 'text',
+        }),
+        dict({
+          'text': 'Hope this helps.',
+          'type': 'text',
+        }),
+      ]),
+      'role': 'assistant',
+    }),
+    dict({
+      'content': 'Are you sure?',
+      'role': 'user',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'text': 'Yes, I am sure!',
+          'type': 'text',
+        }),
+      ]),
+      'role': 'assistant',
+    }),
+  ])
+# ---
+# name: test_history_conversion[content3]
+  list([
+    dict({
+      'content': list([
+        dict({
+          'text': 'What shape is a donut?',
+          'type': 'text',
+        }),
+        dict({
+          'text': 'Can you tell me?',
+          'type': 'text',
+        }),
+        dict({
+          'text': 'Please?',
+          'type': 'text',
+        }),
+      ]),
+      'role': 'user',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'text': 'A donut is a torus.',
+          'type': 'text',
+        }),
+        dict({
+          'text': 'Hope this helps.',
+          'type': 'text',
+        }),
+        dict({
+          'text': 'You are welcome.',
+          'type': 'text',
+        }),
+      ]),
+      'role': 'assistant',
+    }),
+    dict({
+      'content': 'Are you sure?',
+      'role': 'user',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'text': 'Yes, I am sure!',
+          'type': 'text',
+        }),
+      ]),
+      'role': 'assistant',
+    }),
+  ])
+# ---
+# name: test_history_conversion[content4]
+  list([
+    dict({
+      'content': 'Turn off the lights and make me coffee',
+      'role': 'user',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'text': 'Sure.',
+          'type': 'text',
+        }),
+        dict({
+          'id': 'mock-tool-call-id',
+          'input': dict({
+            'domain': 'light',
+          }),
+          'name': 'HassTurnOff',
+          'type': 'tool_use',
+        }),
+        dict({
+          'id': 'mock-tool-call-id-2',
+          'input': dict({
+          }),
+          'name': 'MakeCoffee',
+          'type': 'tool_use',
+        }),
+      ]),
+      'role': 'assistant',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'text': 'Thank you',
+          'type': 'text',
+        }),
+        dict({
+          'content': '{"success": true, "response": "Lights are off."}',
+          'tool_use_id': 'mock-tool-call-id',
+          'type': 'tool_result',
+        }),
+        dict({
+          'content': '{"success": false, "response": "Not enough milk."}',
+          'tool_use_id': 'mock-tool-call-id-2',
+          'type': 'tool_result',
+        }),
+      ]),
+      'role': 'user',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'text': 'Should I add milk to the shopping list?',
+          'type': 'text',
+        }),
+      ]),
+      'role': 'assistant',
+    }),
+    dict({
+      'content': 'Are you sure?',
+      'role': 'user',
+    }),
+    dict({
+      'content': list([
+        dict({
+          'text': 'Yes, I am sure!',
+          'type': 'text',
+        }),
+      ]),
+      'role': 'assistant',
+    }),
+  ])
+# ---
 # name: test_unknown_hass_api
   dict({
     'continue_conversation': False,
tests/components/anthropic/test_config_flow.py

@@ -21,9 +21,11 @@ from homeassistant.components.anthropic.const import (
     CONF_PROMPT,
     CONF_RECOMMENDED,
     CONF_TEMPERATURE,
+    CONF_THINKING_BUDGET,
     DOMAIN,
     RECOMMENDED_CHAT_MODEL,
     RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_THINKING_BUDGET,
 )
 from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant

@@ -94,6 +96,28 @@ async def test_options(
     assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL
 
 
+async def test_options_thinking_budget_more_than_max(
+    hass: HomeAssistant, mock_config_entry, mock_init_component
+) -> None:
+    """Test error about thinking budget being more than max tokens."""
+    options_flow = await hass.config_entries.options.async_init(
+        mock_config_entry.entry_id
+    )
+    options = await hass.config_entries.options.async_configure(
+        options_flow["flow_id"],
+        {
+            "prompt": "Speak like a pirate",
+            "max_tokens": 8192,
+            "chat_model": "claude-3-7-sonnet-latest",
+            "temperature": 1,
+            "thinking_budget": 16384,
+        },
+    )
+    await hass.async_block_till_done()
+    assert options["type"] is FlowResultType.FORM
+    assert options["errors"] == {"thinking_budget": "thinking_budget_too_large"}
+
+
 @pytest.mark.parametrize(
     ("side_effect", "error"),
     [

@@ -186,6 +210,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
             CONF_TEMPERATURE: 0.3,
             CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
             CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+            CONF_THINKING_BUDGET: RECOMMENDED_THINKING_BUDGET,
         },
     ),
     (

@@ -195,6 +220,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
             CONF_TEMPERATURE: 0.3,
             CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
             CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+            CONF_THINKING_BUDGET: RECOMMENDED_THINKING_BUDGET,
         },
         {
             CONF_RECOMMENDED: True,
tests/components/anthropic/test_conversation.py

@@ -14,13 +14,18 @@ from anthropic.types import (
     RawMessageStartEvent,
     RawMessageStopEvent,
     RawMessageStreamEvent,
+    RedactedThinkingBlock,
+    SignatureDelta,
     TextBlock,
     TextDelta,
+    ThinkingBlock,
+    ThinkingDelta,
     ToolUseBlock,
     Usage,
 )
 from freezegun import freeze_time
 from httpx import URL, Request, Response
+import pytest
 from syrupy.assertion import SnapshotAssertion
 import voluptuous as vol
 

@@ -28,7 +33,7 @@ from homeassistant.components import conversation
 from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import Context, HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
-from homeassistant.helpers import intent, llm
+from homeassistant.helpers import chat_session, intent, llm
 from homeassistant.setup import async_setup_component
 from homeassistant.util import ulid as ulid_util
 

@@ -86,6 +91,57 @@ def create_content_block(
     ]
 
 
+def create_thinking_block(
+    index: int, thinking_parts: list[str]
+) -> list[RawMessageStreamEvent]:
+    """Create a thinking block with the specified deltas."""
+    return [
+        RawContentBlockStartEvent(
+            type="content_block_start",
+            content_block=ThinkingBlock(signature="", thinking="", type="thinking"),
+            index=index,
+        ),
+        *[
+            RawContentBlockDeltaEvent(
+                delta=ThinkingDelta(thinking=thinking_part, type="thinking_delta"),
+                index=index,
+                type="content_block_delta",
+            )
+            for thinking_part in thinking_parts
+        ],
+        RawContentBlockDeltaEvent(
+            delta=SignatureDelta(
+                signature="ErUBCkYIARgCIkCYXaVNJShe3A86Hp7XUzh9YsCYBbJTbQsrklTAPtJ2sP/N"
+                "oB6tSzpK/nTL6CjSo2R6n0KNBIg5MH6asM2R/kmaEgyB/X1FtZq5OQAC7jUaDEPWCdcwGQ"
+                "4RaBy5wiIwmRxExIlDhoY6tILoVPnOExkC/0igZxHEwxK8RU/fmw0b+o+TwAarzUitwzbo"
+                "21E5Kh3pa3I6yqVROf1t2F8rFocNUeCegsWV/ytwYV+ayA==",
+                type="signature_delta",
+            ),
+            index=index,
+            type="content_block_delta",
+        ),
+        RawContentBlockStopEvent(index=index, type="content_block_stop"),
+    ]
+
+
+def create_redacted_thinking_block(index: int) -> list[RawMessageStreamEvent]:
+    """Create a redacted thinking block."""
+    return [
+        RawContentBlockStartEvent(
+            type="content_block_start",
+            content_block=RedactedThinkingBlock(
+                data="EroBCkYIARgCKkBJDytPJhw//4vy3t7aE+LfIkxvkAh51cBPrAvBCo6AjgI57Zt9K"
+                "WPnUVV50OQJ0KZzUFoGZG5sxg95zx4qMwkoEgz43Su3myJKckvj03waDBZLIBSeoAeRUeV"
+                "sJCIwQ5edQN0sa+HNeB/KUBkoMUwV+IT0eIhcpFxnILdvxUAKM4R1o4KG3x+yO0eo/kyOK"
+                "iKfrCPFQhvBVmTZPFhgA2Ow8L9gGDVipcz6x3Uu9YETGEny",
+                type="redacted_thinking",
+            ),
+            index=index,
+        ),
+        RawContentBlockStopEvent(index=index, type="content_block_stop"),
+    ]
+
+
 def create_tool_use_block(
     index: int, tool_id: str, tool_name: str, json_parts: list[str]
 ) -> list[RawMessageStreamEvent]:
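The new helpers emit raw SDK events so the stream transformer can be exercised without any network traffic: `create_thinking_block` plays back thinking deltas capped by a signature delta, and `create_redacted_thinking_block` emits a start/stop pair carrying opaque data, matching how the live API behaves. A sample composition in the same style the tests below use (`create_messages` wraps the blocks in message start/stop events; the text is illustrative):

```python
events = create_messages(
    [
        *create_thinking_block(0, ["Let me think", " about this."]),
        *create_redacted_thinking_block(1),
        *create_content_block(2, ["Hello, how can I help you today?"]),
    ]
)
```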
@@ -381,7 +437,7 @@ async def test_function_exception(
         return stream_generator(
             create_messages(
                 [
-                    *create_content_block(0, "Certainly, calling it now!"),
+                    *create_content_block(0, ["Certainly, calling it now!"]),
                     *create_tool_use_block(
                         1,
                         "toolu_0123456789AbCdEfGhIjKlM",

@@ -464,7 +520,7 @@ async def test_assist_api_tools_conversion(
         new_callable=AsyncMock,
         return_value=stream_generator(
             create_messages(
-                create_content_block(0, "Hello, how can I help you?"),
+                create_content_block(0, ["Hello, how can I help you?"]),
             ),
         ),
     ) as mock_create:

@@ -509,7 +565,7 @@ async def test_conversation_id(
     def create_stream_generator(*args, **kwargs) -> Any:
         return stream_generator(
             create_messages(
-                create_content_block(0, "Hello, how can I help you?"),
+                create_content_block(0, ["Hello, how can I help you?"]),
             ),
         )
 

@@ -547,3 +603,283 @@ async def test_conversation_id(
     )
 
     assert result.conversation_id == "koala"
+
+
+async def test_extended_thinking(
+    hass: HomeAssistant,
+    mock_config_entry_with_extended_thinking: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test extended thinking support."""
+    with patch(
+        "anthropic.resources.messages.AsyncMessages.create",
+        new_callable=AsyncMock,
+        return_value=stream_generator(
+            create_messages(
+                [
+                    *create_thinking_block(
+                        0,
+                        [
+                            "The user has just",
+                            ' greeted me with "Hi".',
+                            " This is a simple greeting an",
+                            "d doesn't require any Home Assistant function",
+                            " calls. I should respond with",
+                            " a friendly greeting and let them know I'm available",
+                            " to help with their smart home.",
+                        ],
+                    ),
+                    *create_content_block(1, ["Hello, how can I help you today?"]),
+                ]
+            ),
+        ),
+    ):
+        result = await conversation.async_converse(
+            hass, "hello", None, Context(), agent_id="conversation.claude"
+        )
+
+    chat_log = hass.data.get(conversation.chat_log.DATA_CHAT_LOGS).get(
+        result.conversation_id
+    )
+    assert len(chat_log.content) == 3
+    assert chat_log.content[1].content == "hello"
+    assert chat_log.content[2].content == "Hello, how can I help you today?"
+
+
+async def test_redacted_thinking(
+    hass: HomeAssistant,
+    mock_config_entry_with_extended_thinking: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test extended thinking with redacted thinking blocks."""
+    with patch(
+        "anthropic.resources.messages.AsyncMessages.create",
+        new_callable=AsyncMock,
+        return_value=stream_generator(
+            create_messages(
+                [
+                    *create_redacted_thinking_block(0),
+                    *create_redacted_thinking_block(1),
+                    *create_redacted_thinking_block(2),
+                    *create_content_block(3, ["How can I help you today?"]),
+                ]
+            ),
+        ),
+    ):
+        result = await conversation.async_converse(
+            hass,
+            "ANTHROPIC_MAGIC_STRING_TRIGGER_REDACTED_THINKING_46C9A13E193C177646C7398A9"
+            "8432ECCCE4C1253D5E2D82641AC0E52CC2876CB",
+            None,
+            Context(),
+            agent_id="conversation.claude",
+        )
+
+    chat_log = hass.data.get(conversation.chat_log.DATA_CHAT_LOGS).get(
+        result.conversation_id
+    )
+    assert len(chat_log.content) == 3
+    assert chat_log.content[2].content == "How can I help you today?"
+
+
+@patch("homeassistant.components.anthropic.conversation.llm.AssistAPI._async_get_tools")
+async def test_extended_thinking_tool_call(
+    mock_get_tools,
+    hass: HomeAssistant,
+    mock_config_entry_with_extended_thinking: MockConfigEntry,
+    mock_init_component,
+    snapshot: SnapshotAssertion,
+) -> None:
+    """Test that thinking blocks and their order are preserved in with tool calls."""
+    agent_id = "conversation.claude"
+    context = Context()
+
+    mock_tool = AsyncMock()
+    mock_tool.name = "test_tool"
+    mock_tool.description = "Test function"
+    mock_tool.parameters = vol.Schema(
+        {vol.Optional("param1", description="Test parameters"): str}
+    )
+    mock_tool.async_call.return_value = "Test response"
+
+    mock_get_tools.return_value = [mock_tool]
+
+    def completion_result(*args, messages, **kwargs):
+        for message in messages:
+            for content in message["content"]:
+                if not isinstance(content, str) and content["type"] == "tool_use":
+                    return stream_generator(
+                        create_messages(
+                            create_content_block(
+                                0, ["I have ", "successfully called ", "the function"]
+                            ),
+                        )
+                    )
+
+        return stream_generator(
+            create_messages(
+                [
+                    *create_thinking_block(
+                        0,
+                        [
+                            "The user asked me to",
+                            " call a test function.",
+                            "Is it a test? What",
+                            " would the function",
+                            " do? Would it violate",
+                            " any privacy or security",
+                            " policies?",
+                        ],
+                    ),
+                    *create_redacted_thinking_block(1),
+                    *create_thinking_block(
+                        2, ["Okay, let's give it a shot.", " Will I pass the test?"]
+                    ),
+                    *create_content_block(3, ["Certainly, calling it now!"]),
+                    *create_tool_use_block(
+                        1,
+                        "toolu_0123456789AbCdEfGhIjKlM",
+                        "test_tool",
+                        ['{"para', 'm1": "test_valu', 'e"}'],
+                    ),
+                ]
+            )
+        )
+
+    with (
+        patch(
+            "anthropic.resources.messages.AsyncMessages.create",
+            new_callable=AsyncMock,
+            side_effect=completion_result,
+        ) as mock_create,
+        freeze_time("2024-06-03 23:00:00"),
+    ):
+        result = await conversation.async_converse(
+            hass,
+            "Please call the test function",
+            None,
+            context,
+            agent_id=agent_id,
+        )
+
+    chat_log = hass.data.get(conversation.chat_log.DATA_CHAT_LOGS).get(
+        result.conversation_id
+    )
+
+    assert chat_log.content == snapshot
+    assert mock_create.mock_calls[1][2]["messages"] == snapshot
+
+
+@pytest.mark.parametrize(
+    "content",
+    [
+        [
+            conversation.chat_log.SystemContent("You are a helpful assistant."),
+        ],
+        [
+            conversation.chat_log.SystemContent("You are a helpful assistant."),
+            conversation.chat_log.UserContent("What shape is a donut?"),
+            conversation.chat_log.AssistantContent(
+                agent_id="conversation.claude", content="A donut is a torus."
+            ),
+        ],
+        [
+            conversation.chat_log.SystemContent("You are a helpful assistant."),
+            conversation.chat_log.UserContent("What shape is a donut?"),
+            conversation.chat_log.UserContent("Can you tell me?"),
+            conversation.chat_log.AssistantContent(
+                agent_id="conversation.claude", content="A donut is a torus."
+            ),
+            conversation.chat_log.AssistantContent(
+                agent_id="conversation.claude", content="Hope this helps."
+            ),
+        ],
+        [
+            conversation.chat_log.SystemContent("You are a helpful assistant."),
+            conversation.chat_log.UserContent("What shape is a donut?"),
+            conversation.chat_log.UserContent("Can you tell me?"),
+            conversation.chat_log.UserContent("Please?"),
+            conversation.chat_log.AssistantContent(
+                agent_id="conversation.claude", content="A donut is a torus."
+            ),
+            conversation.chat_log.AssistantContent(
+                agent_id="conversation.claude", content="Hope this helps."
+            ),
+            conversation.chat_log.AssistantContent(
+                agent_id="conversation.claude", content="You are welcome."
+            ),
+        ],
+        [
+            conversation.chat_log.SystemContent("You are a helpful assistant."),
+            conversation.chat_log.UserContent("Turn off the lights and make me coffee"),
+            conversation.chat_log.AssistantContent(
+                agent_id="conversation.claude",
+                content="Sure.",
+                tool_calls=[
+                    llm.ToolInput(
+                        id="mock-tool-call-id",
+                        tool_name="HassTurnOff",
+                        tool_args={"domain": "light"},
+                    ),
+                    llm.ToolInput(
+                        id="mock-tool-call-id-2",
+                        tool_name="MakeCoffee",
+                        tool_args={},
+                    ),
+                ],
+            ),
+            conversation.chat_log.UserContent("Thank you"),
+            conversation.chat_log.ToolResultContent(
+                agent_id="conversation.claude",
+                tool_call_id="mock-tool-call-id",
+                tool_name="HassTurnOff",
+                tool_result={"success": True, "response": "Lights are off."},
+            ),
+            conversation.chat_log.ToolResultContent(
+                agent_id="conversation.claude",
+                tool_call_id="mock-tool-call-id-2",
+                tool_name="MakeCoffee",
+                tool_result={"success": False, "response": "Not enough milk."},
+            ),
+            conversation.chat_log.AssistantContent(
+                agent_id="conversation.claude",
+                content="Should I add milk to the shopping list?",
+            ),
+        ],
+    ],
+)
+async def test_history_conversion(
+    hass: HomeAssistant,
+    mock_config_entry_with_assist: MockConfigEntry,
+    mock_init_component,
+    snapshot: SnapshotAssertion,
+    content: list[conversation.chat_log.Content],
+) -> None:
+    """Test conversion of chat_log entries into API parameters."""
+    conversation_id = "conversation_id"
+    with (
+        chat_session.async_get_chat_session(hass, conversation_id) as session,
+        conversation.async_get_chat_log(hass, session) as chat_log,
+        patch(
+            "anthropic.resources.messages.AsyncMessages.create",
+            new_callable=AsyncMock,
+            return_value=stream_generator(
+                create_messages(
+                    [
+                        *create_content_block(0, ["Yes, I am sure!"]),
+                    ]
+                ),
+            ),
+        ) as mock_create,
+    ):
+        chat_log.content = content
+
+        await conversation.async_converse(
+            hass,
+            "Are you sure?",
+            conversation_id,
+            Context(),
+            agent_id="conversation.claude",
+        )
+
+    assert mock_create.mock_calls[0][2]["messages"] == snapshot