mirror of https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00

Split Anthropic entity (#147770)

This commit is contained in:
parent bf74ba990a
commit 38a7b21052
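
In short: this commit moves the shared LLM plumbing (tool formatting, message conversion, stream transformation, and the tool-call loop) out of conversation.py into a new AnthropicBaseLLMEntity in entity.py, then mixes that base into the conversation entity. A minimal sketch of the resulting shape, with stub bases standing in for Home Assistant's (names as in the diff below, bodies elided):

class Entity: ...
class ConversationEntity(Entity): ...
class AbstractConversationAgent: ...

class AnthropicBaseLLMEntity(Entity):
    """New base in entity.py: device info plus _async_handle_chat_log."""

class AnthropicConversationEntity(
    ConversationEntity,
    AbstractConversationAgent,
    AnthropicBaseLLMEntity,
):
    """conversation.py keeps only the conversation-agent glue."""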
homeassistant/components/anthropic/conversation.py
@@ -1,69 +1,17 @@
 """Conversation support for Anthropic."""

-from collections.abc import AsyncGenerator, Callable, Iterable
-import json
-from typing import Any, Literal, cast
-
-import anthropic
-from anthropic import AsyncStream
-from anthropic._types import NOT_GIVEN
-from anthropic.types import (
-    InputJSONDelta,
-    MessageDeltaUsage,
-    MessageParam,
-    MessageStreamEvent,
-    RawContentBlockDeltaEvent,
-    RawContentBlockStartEvent,
-    RawContentBlockStopEvent,
-    RawMessageDeltaEvent,
-    RawMessageStartEvent,
-    RawMessageStopEvent,
-    RedactedThinkingBlock,
-    RedactedThinkingBlockParam,
-    SignatureDelta,
-    TextBlock,
-    TextBlockParam,
-    TextDelta,
-    ThinkingBlock,
-    ThinkingBlockParam,
-    ThinkingConfigDisabledParam,
-    ThinkingConfigEnabledParam,
-    ThinkingDelta,
-    ToolParam,
-    ToolResultBlockParam,
-    ToolUseBlock,
-    ToolUseBlockParam,
-    Usage,
-)
-from voluptuous_openapi import convert
+from typing import Literal

 from homeassistant.components import conversation
 from homeassistant.config_entries import ConfigEntry, ConfigSubentry
 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import HomeAssistantError
-from homeassistant.helpers import device_registry as dr, intent, llm
+from homeassistant.helpers import intent
 from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback

 from . import AnthropicConfigEntry
-from .const import (
-    CONF_CHAT_MODEL,
-    CONF_MAX_TOKENS,
-    CONF_PROMPT,
-    CONF_TEMPERATURE,
-    CONF_THINKING_BUDGET,
-    DOMAIN,
-    LOGGER,
-    MIN_THINKING_BUDGET,
-    RECOMMENDED_CHAT_MODEL,
-    RECOMMENDED_MAX_TOKENS,
-    RECOMMENDED_TEMPERATURE,
-    RECOMMENDED_THINKING_BUDGET,
-    THINKING_MODELS,
-)
-
-# Max number of back-and-forth exchanges with the LLM to generate a response
-MAX_TOOL_ITERATIONS = 10
-
+from .const import CONF_PROMPT, DOMAIN
+from .entity import AnthropicBaseLLMEntity


 async def async_setup_entry(
@@ -82,253 +30,10 @@ async def async_setup_entry(
     )


-def _format_tool(
-    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
-) -> ToolParam:
-    """Format tool specification."""
-    return ToolParam(
-        name=tool.name,
-        description=tool.description or "",
-        input_schema=convert(tool.parameters, custom_serializer=custom_serializer),
-    )
-
-
-def _convert_content(
-    chat_content: Iterable[conversation.Content],
-) -> list[MessageParam]:
-    """Transform HA chat_log content into Anthropic API format."""
-    messages: list[MessageParam] = []
-
-    for content in chat_content:
-        if isinstance(content, conversation.ToolResultContent):
-            tool_result_block = ToolResultBlockParam(
-                type="tool_result",
-                tool_use_id=content.tool_call_id,
-                content=json.dumps(content.tool_result),
-            )
-            if not messages or messages[-1]["role"] != "user":
-                messages.append(
-                    MessageParam(
-                        role="user",
-                        content=[tool_result_block],
-                    )
-                )
-            elif isinstance(messages[-1]["content"], str):
-                messages[-1]["content"] = [
-                    TextBlockParam(type="text", text=messages[-1]["content"]),
-                    tool_result_block,
-                ]
-            else:
-                messages[-1]["content"].append(tool_result_block)  # type: ignore[attr-defined]
-        elif isinstance(content, conversation.UserContent):
-            # Combine consecutive user messages
-            if not messages or messages[-1]["role"] != "user":
-                messages.append(
-                    MessageParam(
-                        role="user",
-                        content=content.content,
-                    )
-                )
-            elif isinstance(messages[-1]["content"], str):
-                messages[-1]["content"] = [
-                    TextBlockParam(type="text", text=messages[-1]["content"]),
-                    TextBlockParam(type="text", text=content.content),
-                ]
-            else:
-                messages[-1]["content"].append(  # type: ignore[attr-defined]
-                    TextBlockParam(type="text", text=content.content)
-                )
-        elif isinstance(content, conversation.AssistantContent):
-            # Combine consecutive assistant messages
-            if not messages or messages[-1]["role"] != "assistant":
-                messages.append(
-                    MessageParam(
-                        role="assistant",
-                        content=[],
-                    )
-                )
-
-            if content.content:
-                messages[-1]["content"].append(  # type: ignore[union-attr]
-                    TextBlockParam(type="text", text=content.content)
-                )
-            if content.tool_calls:
-                messages[-1]["content"].extend(  # type: ignore[union-attr]
-                    [
-                        ToolUseBlockParam(
-                            type="tool_use",
-                            id=tool_call.id,
-                            name=tool_call.tool_name,
-                            input=tool_call.tool_args,
-                        )
-                        for tool_call in content.tool_calls
-                    ]
-                )
-        else:
-            # Note: We don't pass SystemContent here as it's passed to the API as the prompt
-            raise TypeError(f"Unexpected content type: {type(content)}")
-
-    return messages
-
-
-async def _transform_stream(  # noqa: C901 - This is complex, but better to have it in one place
-    chat_log: conversation.ChatLog,
-    result: AsyncStream[MessageStreamEvent],
-    messages: list[MessageParam],
-) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
-    """Transform the response stream into HA format.
-
-    A typical stream of responses might look something like the following:
-    - RawMessageStartEvent with no content
-    - RawContentBlockStartEvent with an empty ThinkingBlock (if extended thinking is enabled)
-    - RawContentBlockDeltaEvent with a ThinkingDelta
-    - RawContentBlockDeltaEvent with a ThinkingDelta
-    - RawContentBlockDeltaEvent with a ThinkingDelta
-    - ...
-    - RawContentBlockDeltaEvent with a SignatureDelta
-    - RawContentBlockStopEvent
-    - RawContentBlockStartEvent with a RedactedThinkingBlock (occasionally)
-    - RawContentBlockStopEvent (RedactedThinkingBlock does not have a delta)
-    - RawContentBlockStartEvent with an empty TextBlock
-    - RawContentBlockDeltaEvent with a TextDelta
-    - RawContentBlockDeltaEvent with a TextDelta
-    - RawContentBlockDeltaEvent with a TextDelta
-    - ...
-    - RawContentBlockStopEvent
-    - RawContentBlockStartEvent with ToolUseBlock specifying the function name
-    - RawContentBlockDeltaEvent with an InputJSONDelta
-    - RawContentBlockDeltaEvent with an InputJSONDelta
-    - ...
-    - RawContentBlockStopEvent
-    - RawMessageDeltaEvent with a stop_reason='tool_use'
-    - RawMessageStopEvent(type='message_stop')
-
-    Each message could contain multiple blocks of the same type.
-    """
-    if result is None:
-        raise TypeError("Expected a stream of messages")
-
-    current_message: MessageParam | None = None
-    current_block: (
-        TextBlockParam
-        | ToolUseBlockParam
-        | ThinkingBlockParam
-        | RedactedThinkingBlockParam
-        | None
-    ) = None
-    current_tool_args: str
-    input_usage: Usage | None = None
-
-    async for response in result:
-        LOGGER.debug("Received response: %s", response)
-
-        if isinstance(response, RawMessageStartEvent):
-            if response.message.role != "assistant":
-                raise ValueError("Unexpected message role")
-            current_message = MessageParam(role=response.message.role, content=[])
-            input_usage = response.message.usage
-        elif isinstance(response, RawContentBlockStartEvent):
-            if isinstance(response.content_block, ToolUseBlock):
-                current_block = ToolUseBlockParam(
-                    type="tool_use",
-                    id=response.content_block.id,
-                    name=response.content_block.name,
-                    input="",
-                )
-                current_tool_args = ""
-            elif isinstance(response.content_block, TextBlock):
-                current_block = TextBlockParam(
-                    type="text", text=response.content_block.text
-                )
-                yield {"role": "assistant"}
-                if response.content_block.text:
-                    yield {"content": response.content_block.text}
-            elif isinstance(response.content_block, ThinkingBlock):
-                current_block = ThinkingBlockParam(
-                    type="thinking",
-                    thinking=response.content_block.thinking,
-                    signature=response.content_block.signature,
-                )
-            elif isinstance(response.content_block, RedactedThinkingBlock):
-                current_block = RedactedThinkingBlockParam(
-                    type="redacted_thinking", data=response.content_block.data
-                )
-                LOGGER.debug(
-                    "Some of Claude’s internal reasoning has been automatically "
-                    "encrypted for safety reasons. This doesn’t affect the quality of "
-                    "responses"
-                )
-        elif isinstance(response, RawContentBlockDeltaEvent):
-            if current_block is None:
-                raise ValueError("Unexpected delta without a block")
-            if isinstance(response.delta, InputJSONDelta):
-                current_tool_args += response.delta.partial_json
-            elif isinstance(response.delta, TextDelta):
-                text_block = cast(TextBlockParam, current_block)
-                text_block["text"] += response.delta.text
-                yield {"content": response.delta.text}
-            elif isinstance(response.delta, ThinkingDelta):
-                thinking_block = cast(ThinkingBlockParam, current_block)
-                thinking_block["thinking"] += response.delta.thinking
-            elif isinstance(response.delta, SignatureDelta):
-                thinking_block = cast(ThinkingBlockParam, current_block)
-                thinking_block["signature"] += response.delta.signature
-        elif isinstance(response, RawContentBlockStopEvent):
-            if current_block is None:
-                raise ValueError("Unexpected stop event without a current block")
-            if current_block["type"] == "tool_use":
-                # tool block
-                tool_args = json.loads(current_tool_args) if current_tool_args else {}
-                current_block["input"] = tool_args
-                yield {
-                    "tool_calls": [
-                        llm.ToolInput(
-                            id=current_block["id"],
-                            tool_name=current_block["name"],
-                            tool_args=tool_args,
-                        )
-                    ]
-                }
-            elif current_block["type"] == "thinking":
-                # thinking block
-                LOGGER.debug("Thinking: %s", current_block["thinking"])
-
-            if current_message is None:
-                raise ValueError("Unexpected stop event without a current message")
-            current_message["content"].append(current_block)  # type: ignore[union-attr]
-            current_block = None
-        elif isinstance(response, RawMessageDeltaEvent):
-            if (usage := response.usage) is not None:
-                chat_log.async_trace(_create_token_stats(input_usage, usage))
-            if response.delta.stop_reason == "refusal":
-                raise HomeAssistantError("Potential policy violation detected")
-        elif isinstance(response, RawMessageStopEvent):
-            if current_message is not None:
-                messages.append(current_message)
-                current_message = None
-
-
-def _create_token_stats(
-    input_usage: Usage | None, response_usage: MessageDeltaUsage
-) -> dict[str, Any]:
-    """Create token stats for conversation agent tracing."""
-    input_tokens = 0
-    cached_input_tokens = 0
-    if input_usage:
-        input_tokens = input_usage.input_tokens
-        cached_input_tokens = input_usage.cache_creation_input_tokens or 0
-    output_tokens = response_usage.output_tokens
-    return {
-        "stats": {
-            "input_tokens": input_tokens,
-            "cached_input_tokens": cached_input_tokens,
-            "output_tokens": output_tokens,
-        }
-    }
-
-
 class AnthropicConversationEntity(
-    conversation.ConversationEntity, conversation.AbstractConversationAgent
+    conversation.ConversationEntity,
+    conversation.AbstractConversationAgent,
+    AnthropicBaseLLMEntity,
 ):
     """Anthropic conversation agent."""

@@ -336,17 +41,7 @@ class AnthropicConversationEntity(

     def __init__(self, entry: AnthropicConfigEntry, subentry: ConfigSubentry) -> None:
         """Initialize the agent."""
-        self.entry = entry
-        self.subentry = subentry
-        self._attr_name = subentry.title
-        self._attr_unique_id = subentry.subentry_id
-        self._attr_device_info = dr.DeviceInfo(
-            identifiers={(DOMAIN, subentry.subentry_id)},
-            name=subentry.title,
-            manufacturer="Anthropic",
-            model="Claude",
-            entry_type=dr.DeviceEntryType.SERVICE,
-        )
+        super().__init__(entry, subentry)
         if self.subentry.data.get(CONF_LLM_HASS_API):
             self._attr_supported_features = (
                 conversation.ConversationEntityFeature.CONTROL
@@ -395,73 +90,6 @@ class AnthropicConversationEntity(
             continue_conversation=chat_log.continue_conversation,
         )

-    async def _async_handle_chat_log(
-        self,
-        chat_log: conversation.ChatLog,
-    ) -> None:
-        """Generate an answer for the chat log."""
-        options = self.subentry.data
-
-        tools: list[ToolParam] | None = None
-        if chat_log.llm_api:
-            tools = [
-                _format_tool(tool, chat_log.llm_api.custom_serializer)
-                for tool in chat_log.llm_api.tools
-            ]
-
-        system = chat_log.content[0]
-        if not isinstance(system, conversation.SystemContent):
-            raise TypeError("First message must be a system message")
-        messages = _convert_content(chat_log.content[1:])
-
-        client = self.entry.runtime_data
-
-        thinking_budget = options.get(CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET)
-        model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
-
-        # To prevent infinite loops, we limit the number of iterations
-        for _iteration in range(MAX_TOOL_ITERATIONS):
-            model_args = {
-                "model": model,
-                "messages": messages,
-                "tools": tools or NOT_GIVEN,
-                "max_tokens": options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
-                "system": system.content,
-                "stream": True,
-            }
-            if model in THINKING_MODELS and thinking_budget >= MIN_THINKING_BUDGET:
-                model_args["thinking"] = ThinkingConfigEnabledParam(
-                    type="enabled", budget_tokens=thinking_budget
-                )
-            else:
-                model_args["thinking"] = ThinkingConfigDisabledParam(type="disabled")
-                model_args["temperature"] = options.get(
-                    CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
-                )
-
-            try:
-                stream = await client.messages.create(**model_args)
-            except anthropic.AnthropicError as err:
-                raise HomeAssistantError(
-                    f"Sorry, I had a problem talking to Anthropic: {err}"
-                ) from err
-
-            messages.extend(
-                _convert_content(
-                    [
-                        content
-                        async for content in chat_log.async_add_delta_content_stream(
-                            self.entity_id,
-                            _transform_stream(chat_log, stream, messages),
-                        )
-                        if not isinstance(content, conversation.AssistantContent)
-                    ]
-                )
-            )
-
-            if not chat_log.unresponded_tool_results:
-                break
-
     async def _async_entry_update_listener(
         self, hass: HomeAssistant, entry: ConfigEntry
     ) -> None:
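
The moved _convert_content keeps its behavior: consecutive same-role content is folded into one Anthropic message, promoting a bare string to a list of text blocks when a second block arrives. A self-contained toy run of that folding rule (plain dicts stand in for MessageParam/TextBlockParam; the inputs are invented):

messages = [{"role": "user", "content": "turn on the lights"}]
incoming = "in the kitchen"

if messages and messages[-1]["role"] == "user":
    if isinstance(messages[-1]["content"], str):
        # Promote the bare string to a text block, then add the new one.
        messages[-1]["content"] = [
            {"type": "text", "text": messages[-1]["content"]},
            {"type": "text", "text": incoming},
        ]
    else:
        messages[-1]["content"].append({"type": "text", "text": incoming})
else:
    messages.append({"role": "user", "content": incoming})

assert messages == [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "turn on the lights"},
            {"type": "text", "text": "in the kitchen"},
        ],
    }
]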
homeassistant/components/anthropic/entity.py (new file, 393 lines)
@@ -0,0 +1,393 @@
+"""Base entity for Anthropic."""
+
+from collections.abc import AsyncGenerator, Callable, Iterable
+import json
+from typing import Any, cast
+
+import anthropic
+from anthropic import AsyncStream
+from anthropic._types import NOT_GIVEN
+from anthropic.types import (
+    InputJSONDelta,
+    MessageDeltaUsage,
+    MessageParam,
+    MessageStreamEvent,
+    RawContentBlockDeltaEvent,
+    RawContentBlockStartEvent,
+    RawContentBlockStopEvent,
+    RawMessageDeltaEvent,
+    RawMessageStartEvent,
+    RawMessageStopEvent,
+    RedactedThinkingBlock,
+    RedactedThinkingBlockParam,
+    SignatureDelta,
+    TextBlock,
+    TextBlockParam,
+    TextDelta,
+    ThinkingBlock,
+    ThinkingBlockParam,
+    ThinkingConfigDisabledParam,
+    ThinkingConfigEnabledParam,
+    ThinkingDelta,
+    ToolParam,
+    ToolResultBlockParam,
+    ToolUseBlock,
+    ToolUseBlockParam,
+    Usage,
+)
+from voluptuous_openapi import convert
+
+from homeassistant.components import conversation
+from homeassistant.config_entries import ConfigSubentry
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import device_registry as dr, llm
+from homeassistant.helpers.entity import Entity
+
+from . import AnthropicConfigEntry
+from .const import (
+    CONF_CHAT_MODEL,
+    CONF_MAX_TOKENS,
+    CONF_TEMPERATURE,
+    CONF_THINKING_BUDGET,
+    DOMAIN,
+    LOGGER,
+    MIN_THINKING_BUDGET,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_THINKING_BUDGET,
+    THINKING_MODELS,
+)
+
+# Max number of back-and-forth exchanges with the LLM to generate a response
+MAX_TOOL_ITERATIONS = 10
+
+
+def _format_tool(
+    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
+) -> ToolParam:
+    """Format tool specification."""
+    return ToolParam(
+        name=tool.name,
+        description=tool.description or "",
+        input_schema=convert(tool.parameters, custom_serializer=custom_serializer),
+    )
+
+
+def _convert_content(
+    chat_content: Iterable[conversation.Content],
+) -> list[MessageParam]:
+    """Transform HA chat_log content into Anthropic API format."""
+    messages: list[MessageParam] = []
+
+    for content in chat_content:
+        if isinstance(content, conversation.ToolResultContent):
+            tool_result_block = ToolResultBlockParam(
+                type="tool_result",
+                tool_use_id=content.tool_call_id,
+                content=json.dumps(content.tool_result),
+            )
+            if not messages or messages[-1]["role"] != "user":
+                messages.append(
+                    MessageParam(
+                        role="user",
+                        content=[tool_result_block],
+                    )
+                )
+            elif isinstance(messages[-1]["content"], str):
+                messages[-1]["content"] = [
+                    TextBlockParam(type="text", text=messages[-1]["content"]),
+                    tool_result_block,
+                ]
+            else:
+                messages[-1]["content"].append(tool_result_block)  # type: ignore[attr-defined]
+        elif isinstance(content, conversation.UserContent):
+            # Combine consecutive user messages
+            if not messages or messages[-1]["role"] != "user":
+                messages.append(
+                    MessageParam(
+                        role="user",
+                        content=content.content,
+                    )
+                )
+            elif isinstance(messages[-1]["content"], str):
+                messages[-1]["content"] = [
+                    TextBlockParam(type="text", text=messages[-1]["content"]),
+                    TextBlockParam(type="text", text=content.content),
+                ]
+            else:
+                messages[-1]["content"].append(  # type: ignore[attr-defined]
+                    TextBlockParam(type="text", text=content.content)
+                )
+        elif isinstance(content, conversation.AssistantContent):
+            # Combine consecutive assistant messages
+            if not messages or messages[-1]["role"] != "assistant":
+                messages.append(
+                    MessageParam(
+                        role="assistant",
+                        content=[],
+                    )
+                )
+
+            if content.content:
+                messages[-1]["content"].append(  # type: ignore[union-attr]
+                    TextBlockParam(type="text", text=content.content)
+                )
+            if content.tool_calls:
+                messages[-1]["content"].extend(  # type: ignore[union-attr]
+                    [
+                        ToolUseBlockParam(
+                            type="tool_use",
+                            id=tool_call.id,
+                            name=tool_call.tool_name,
+                            input=tool_call.tool_args,
+                        )
+                        for tool_call in content.tool_calls
+                    ]
+                )
+        else:
+            # Note: We don't pass SystemContent here as it's passed to the API as the prompt
+            raise TypeError(f"Unexpected content type: {type(content)}")
+
+    return messages
+
+
+async def _transform_stream(  # noqa: C901 - This is complex, but better to have it in one place
+    chat_log: conversation.ChatLog,
+    result: AsyncStream[MessageStreamEvent],
+    messages: list[MessageParam],
+) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
+    """Transform the response stream into HA format.
+
+    A typical stream of responses might look something like the following:
+    - RawMessageStartEvent with no content
+    - RawContentBlockStartEvent with an empty ThinkingBlock (if extended thinking is enabled)
+    - RawContentBlockDeltaEvent with a ThinkingDelta
+    - RawContentBlockDeltaEvent with a ThinkingDelta
+    - RawContentBlockDeltaEvent with a ThinkingDelta
+    - ...
+    - RawContentBlockDeltaEvent with a SignatureDelta
+    - RawContentBlockStopEvent
+    - RawContentBlockStartEvent with a RedactedThinkingBlock (occasionally)
+    - RawContentBlockStopEvent (RedactedThinkingBlock does not have a delta)
+    - RawContentBlockStartEvent with an empty TextBlock
+    - RawContentBlockDeltaEvent with a TextDelta
+    - RawContentBlockDeltaEvent with a TextDelta
+    - RawContentBlockDeltaEvent with a TextDelta
+    - ...
+    - RawContentBlockStopEvent
+    - RawContentBlockStartEvent with ToolUseBlock specifying the function name
+    - RawContentBlockDeltaEvent with an InputJSONDelta
+    - RawContentBlockDeltaEvent with an InputJSONDelta
+    - ...
+    - RawContentBlockStopEvent
+    - RawMessageDeltaEvent with a stop_reason='tool_use'
+    - RawMessageStopEvent(type='message_stop')
+
+    Each message could contain multiple blocks of the same type.
+    """
+    if result is None:
+        raise TypeError("Expected a stream of messages")
+
+    current_message: MessageParam | None = None
+    current_block: (
+        TextBlockParam
+        | ToolUseBlockParam
+        | ThinkingBlockParam
+        | RedactedThinkingBlockParam
+        | None
+    ) = None
+    current_tool_args: str
+    input_usage: Usage | None = None
+
+    async for response in result:
+        LOGGER.debug("Received response: %s", response)
+
+        if isinstance(response, RawMessageStartEvent):
+            if response.message.role != "assistant":
+                raise ValueError("Unexpected message role")
+            current_message = MessageParam(role=response.message.role, content=[])
+            input_usage = response.message.usage
+        elif isinstance(response, RawContentBlockStartEvent):
+            if isinstance(response.content_block, ToolUseBlock):
+                current_block = ToolUseBlockParam(
+                    type="tool_use",
+                    id=response.content_block.id,
+                    name=response.content_block.name,
+                    input="",
+                )
+                current_tool_args = ""
+            elif isinstance(response.content_block, TextBlock):
+                current_block = TextBlockParam(
+                    type="text", text=response.content_block.text
+                )
+                yield {"role": "assistant"}
+                if response.content_block.text:
+                    yield {"content": response.content_block.text}
+            elif isinstance(response.content_block, ThinkingBlock):
+                current_block = ThinkingBlockParam(
+                    type="thinking",
+                    thinking=response.content_block.thinking,
+                    signature=response.content_block.signature,
+                )
+            elif isinstance(response.content_block, RedactedThinkingBlock):
+                current_block = RedactedThinkingBlockParam(
+                    type="redacted_thinking", data=response.content_block.data
+                )
+                LOGGER.debug(
+                    "Some of Claude’s internal reasoning has been automatically "
+                    "encrypted for safety reasons. This doesn’t affect the quality of "
+                    "responses"
+                )
+        elif isinstance(response, RawContentBlockDeltaEvent):
+            if current_block is None:
+                raise ValueError("Unexpected delta without a block")
+            if isinstance(response.delta, InputJSONDelta):
+                current_tool_args += response.delta.partial_json
+            elif isinstance(response.delta, TextDelta):
+                text_block = cast(TextBlockParam, current_block)
+                text_block["text"] += response.delta.text
+                yield {"content": response.delta.text}
+            elif isinstance(response.delta, ThinkingDelta):
+                thinking_block = cast(ThinkingBlockParam, current_block)
+                thinking_block["thinking"] += response.delta.thinking
+            elif isinstance(response.delta, SignatureDelta):
+                thinking_block = cast(ThinkingBlockParam, current_block)
+                thinking_block["signature"] += response.delta.signature
+        elif isinstance(response, RawContentBlockStopEvent):
+            if current_block is None:
+                raise ValueError("Unexpected stop event without a current block")
+            if current_block["type"] == "tool_use":
+                # tool block
+                tool_args = json.loads(current_tool_args) if current_tool_args else {}
+                current_block["input"] = tool_args
+                yield {
+                    "tool_calls": [
+                        llm.ToolInput(
+                            id=current_block["id"],
+                            tool_name=current_block["name"],
+                            tool_args=tool_args,
+                        )
+                    ]
+                }
+            elif current_block["type"] == "thinking":
+                # thinking block
+                LOGGER.debug("Thinking: %s", current_block["thinking"])
+
+            if current_message is None:
+                raise ValueError("Unexpected stop event without a current message")
+            current_message["content"].append(current_block)  # type: ignore[union-attr]
+            current_block = None
+        elif isinstance(response, RawMessageDeltaEvent):
+            if (usage := response.usage) is not None:
+                chat_log.async_trace(_create_token_stats(input_usage, usage))
+            if response.delta.stop_reason == "refusal":
+                raise HomeAssistantError("Potential policy violation detected")
+        elif isinstance(response, RawMessageStopEvent):
+            if current_message is not None:
+                messages.append(current_message)
+                current_message = None
+
+
+def _create_token_stats(
+    input_usage: Usage | None, response_usage: MessageDeltaUsage
+) -> dict[str, Any]:
+    """Create token stats for conversation agent tracing."""
+    input_tokens = 0
+    cached_input_tokens = 0
+    if input_usage:
+        input_tokens = input_usage.input_tokens
+        cached_input_tokens = input_usage.cache_creation_input_tokens or 0
+    output_tokens = response_usage.output_tokens
+    return {
+        "stats": {
+            "input_tokens": input_tokens,
+            "cached_input_tokens": cached_input_tokens,
+            "output_tokens": output_tokens,
+        }
+    }
+
+
+class AnthropicBaseLLMEntity(Entity):
+    """Anthropic base LLM entity."""
+
+    def __init__(self, entry: AnthropicConfigEntry, subentry: ConfigSubentry) -> None:
+        """Initialize the entity."""
+        self.entry = entry
+        self.subentry = subentry
+        self._attr_name = subentry.title
+        self._attr_unique_id = subentry.subentry_id
+        self._attr_device_info = dr.DeviceInfo(
+            identifiers={(DOMAIN, subentry.subentry_id)},
+            name=subentry.title,
+            manufacturer="Anthropic",
+            model="Claude",
+            entry_type=dr.DeviceEntryType.SERVICE,
+        )
+
+    async def _async_handle_chat_log(
+        self,
+        chat_log: conversation.ChatLog,
+    ) -> None:
+        """Generate an answer for the chat log."""
+        options = self.subentry.data
+
+        tools: list[ToolParam] | None = None
+        if chat_log.llm_api:
+            tools = [
+                _format_tool(tool, chat_log.llm_api.custom_serializer)
+                for tool in chat_log.llm_api.tools
+            ]
+
+        system = chat_log.content[0]
+        if not isinstance(system, conversation.SystemContent):
+            raise TypeError("First message must be a system message")
+        messages = _convert_content(chat_log.content[1:])
+
+        client = self.entry.runtime_data
+
+        thinking_budget = options.get(CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET)
+        model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
+
+        # To prevent infinite loops, we limit the number of iterations
+        for _iteration in range(MAX_TOOL_ITERATIONS):
+            model_args = {
+                "model": model,
+                "messages": messages,
+                "tools": tools or NOT_GIVEN,
+                "max_tokens": options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
+                "system": system.content,
+                "stream": True,
+            }
+            if model in THINKING_MODELS and thinking_budget >= MIN_THINKING_BUDGET:
+                model_args["thinking"] = ThinkingConfigEnabledParam(
+                    type="enabled", budget_tokens=thinking_budget
+                )
+            else:
+                model_args["thinking"] = ThinkingConfigDisabledParam(type="disabled")
+                model_args["temperature"] = options.get(
+                    CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
+                )
+
+            try:
+                stream = await client.messages.create(**model_args)
+            except anthropic.AnthropicError as err:
+                raise HomeAssistantError(
+                    f"Sorry, I had a problem talking to Anthropic: {err}"
+                ) from err
+
+            messages.extend(
+                _convert_content(
+                    [
+                        content
+                        async for content in chat_log.async_add_delta_content_stream(
+                            self.entity_id,
+                            _transform_stream(chat_log, stream, messages),
+                        )
+                        if not isinstance(content, conversation.AssistantContent)
+                    ]
+                )
+            )
+
+            if not chat_log.unresponded_tool_results:
+                break
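
One detail worth noting in the moved _transform_stream: tool arguments arrive as InputJSONDelta fragments of partial JSON, accumulated into current_tool_args and parsed only when the content block stops. A self-contained toy run of that accumulation (the fragments are invented):

import json

fragments = ['{"name": "kitchen', ' light", "brightness": 70}']

current_tool_args = ""
for partial_json in fragments:  # what InputJSONDelta.partial_json carries
    current_tool_args += partial_json

# Parsed once, at RawContentBlockStopEvent time; empty input means no args.
tool_args = json.loads(current_tool_args) if current_tool_args else {}
assert tool_args == {"name": "kitchen light", "brightness": 70}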
tests/components/anthropic/test_conversation.py
@@ -316,7 +316,7 @@ async def test_conversation_agent(
     assert agent.supported_languages == "*"


-@patch("homeassistant.components.anthropic.conversation.llm.AssistAPI._async_get_tools")
+@patch("homeassistant.components.anthropic.entity.llm.AssistAPI._async_get_tools")
 @pytest.mark.parametrize(
     ("tool_call_json_parts", "expected_call_tool_args"),
     [
@@ -430,7 +430,7 @@ async def test_function_call(
     )


-@patch("homeassistant.components.anthropic.conversation.llm.AssistAPI._async_get_tools")
+@patch("homeassistant.components.anthropic.entity.llm.AssistAPI._async_get_tools")
 async def test_function_exception(
     mock_get_tools,
     hass: HomeAssistant,
@@ -760,7 +760,7 @@ async def test_redacted_thinking(
     assert chat_log.content[2].content == "How can I help you today?"


-@patch("homeassistant.components.anthropic.conversation.llm.AssistAPI._async_get_tools")
+@patch("homeassistant.components.anthropic.entity.llm.AssistAPI._async_get_tools")
 async def test_extended_thinking_tool_call(
     mock_get_tools,
     hass: HomeAssistant,
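
The test changes are mechanical: unittest.mock.patch must target the module where a name is looked up, and the llm reference now lives in entity.py rather than conversation.py. A self-contained sketch of that rule, with a class standing in for the consuming module (all names here are invented for illustration):

import json
from unittest.mock import patch

class FakeEntityModule:
    # The consumer captured its own reference, like a module-level import.
    loads = staticmethod(json.loads)

# Patching the original definition does not touch the captured reference:
with patch("json.loads", return_value={"patched": True}):
    assert FakeEntityModule.loads("{}") == {}
# Patching the consumer's attribute is what the updated tests do:
with patch.object(FakeEntityModule, "loads", return_value={"patched": True}):
    assert FakeEntityModule.loads("{}") == {"patched": True}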
Loading…
x
Reference in New Issue
Block a user