# Mirror of https://github.com/home-assistant/core.git
# Synced 2025-04-26 02:07:54 +00:00
"""Conversation support for Anthropic."""
|
||
|
||
from collections.abc import AsyncGenerator, Callable, Iterable
|
||
import json
|
||
from typing import Any, Literal, cast
|
||
|
||
import anthropic
|
||
from anthropic import AsyncStream
|
||
from anthropic._types import NOT_GIVEN
|
||
from anthropic.types import (
|
||
InputJSONDelta,
|
||
MessageParam,
|
||
MessageStreamEvent,
|
||
RawContentBlockDeltaEvent,
|
||
RawContentBlockStartEvent,
|
||
RawContentBlockStopEvent,
|
||
RawMessageStartEvent,
|
||
RawMessageStopEvent,
|
||
RedactedThinkingBlock,
|
||
RedactedThinkingBlockParam,
|
||
SignatureDelta,
|
||
TextBlock,
|
||
TextBlockParam,
|
||
TextDelta,
|
||
ThinkingBlock,
|
||
ThinkingBlockParam,
|
||
ThinkingConfigDisabledParam,
|
||
ThinkingConfigEnabledParam,
|
||
ThinkingDelta,
|
||
ToolParam,
|
||
ToolResultBlockParam,
|
||
ToolUseBlock,
|
||
ToolUseBlockParam,
|
||
)
|
||
from voluptuous_openapi import convert
|
||
|
||
from homeassistant.components import conversation
|
||
from homeassistant.config_entries import ConfigEntry
|
||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||
from homeassistant.core import HomeAssistant
|
||
from homeassistant.exceptions import HomeAssistantError
|
||
from homeassistant.helpers import device_registry as dr, intent, llm
|
||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||
|
||
from . import AnthropicConfigEntry
|
||
from .const import (
|
||
CONF_CHAT_MODEL,
|
||
CONF_MAX_TOKENS,
|
||
CONF_PROMPT,
|
||
CONF_TEMPERATURE,
|
||
CONF_THINKING_BUDGET,
|
||
DOMAIN,
|
||
LOGGER,
|
||
MIN_THINKING_BUDGET,
|
||
RECOMMENDED_CHAT_MODEL,
|
||
RECOMMENDED_MAX_TOKENS,
|
||
RECOMMENDED_TEMPERATURE,
|
||
RECOMMENDED_THINKING_BUDGET,
|
||
THINKING_MODELS,
|
||
)
|
||
|
||
# Max number of back and forth with the LLM to generate a response
|
||
MAX_TOOL_ITERATIONS = 10
|
||
|
||
|
||
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: AnthropicConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up the conversation entity for an Anthropic config entry."""
    async_add_entities([AnthropicConversationEntity(config_entry)])
|
||
|
||
|
||
def _format_tool(
    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> ToolParam:
    """Convert an HA LLM tool definition into an Anthropic ToolParam."""
    # voluptuous_openapi translates the voluptuous schema to JSON Schema,
    # which is the format Anthropic expects for tool inputs.
    input_schema = convert(tool.parameters, custom_serializer=custom_serializer)
    return ToolParam(
        name=tool.name,
        description=tool.description or "",
        input_schema=input_schema,
    )
|
||
|
||
|
||
def _convert_content(
    chat_content: Iterable[conversation.Content],
) -> list[MessageParam]:
    """Transform HA chat_log content into Anthropic API format."""
    messages: list[MessageParam] = []

    def _merge_into_last(block: Any) -> None:
        """Append a block to the last message, normalizing a string body.

        A message whose content is still a plain string is first rewritten
        as a list containing an equivalent text block so the new block can
        be appended after it.
        """
        body = messages[-1]["content"]
        if isinstance(body, str):
            messages[-1]["content"] = [
                TextBlockParam(type="text", text=body),
                block,
            ]
        else:
            body.append(block)  # type: ignore[attr-defined]

    for piece in chat_content:
        if isinstance(piece, conversation.ToolResultContent):
            result_block = ToolResultBlockParam(
                type="tool_result",
                tool_use_id=piece.tool_call_id,
                content=json.dumps(piece.tool_result),
            )
            # Tool results must travel in a "user" role message; merge into
            # the previous user message when there is one.
            if not messages or messages[-1]["role"] != "user":
                messages.append(
                    MessageParam(role="user", content=[result_block])
                )
            else:
                _merge_into_last(result_block)
        elif isinstance(piece, conversation.UserContent):
            # Combine consecutive user messages into a single API message.
            if not messages or messages[-1]["role"] != "user":
                messages.append(
                    MessageParam(role="user", content=piece.content)
                )
            else:
                _merge_into_last(TextBlockParam(type="text", text=piece.content))
        elif isinstance(piece, conversation.AssistantContent):
            # Combine consecutive assistant messages into a single API message.
            if not messages or messages[-1]["role"] != "assistant":
                messages.append(MessageParam(role="assistant", content=[]))

            if piece.content:
                messages[-1]["content"].append(  # type: ignore[union-attr]
                    TextBlockParam(type="text", text=piece.content)
                )
            if piece.tool_calls:
                messages[-1]["content"].extend(  # type: ignore[union-attr]
                    [
                        ToolUseBlockParam(
                            type="tool_use",
                            id=call.id,
                            name=call.tool_name,
                            input=call.tool_args,
                        )
                        for call in piece.tool_calls
                    ]
                )
        else:
            # Note: We don't pass SystemContent here as its passed to the API as the prompt
            raise TypeError(f"Unexpected content type: {type(piece)}")

    return messages
|
||
|
||
|
||
async def _transform_stream(
    result: AsyncStream[MessageStreamEvent],
    messages: list[MessageParam],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
    """Transform the response stream into HA format.

    Yields delta dicts consumed by the chat log, and as a side effect
    appends each completed assistant message (including thinking and
    tool-use blocks) to *messages* so it can be sent back to the API on
    the next tool iteration.

    A typical stream of responses might look something like the following:
    - RawMessageStartEvent with no content
    - RawContentBlockStartEvent with an empty ThinkingBlock (if extended thinking is enabled)
    - RawContentBlockDeltaEvent with a ThinkingDelta
    - RawContentBlockDeltaEvent with a ThinkingDelta
    - RawContentBlockDeltaEvent with a ThinkingDelta
    - ...
    - RawContentBlockDeltaEvent with a SignatureDelta
    - RawContentBlockStopEvent
    - RawContentBlockStartEvent with a RedactedThinkingBlock (occasionally)
    - RawContentBlockStopEvent (RedactedThinkingBlock does not have a delta)
    - RawContentBlockStartEvent with an empty TextBlock
    - RawContentBlockDeltaEvent with a TextDelta
    - RawContentBlockDeltaEvent with a TextDelta
    - RawContentBlockDeltaEvent with a TextDelta
    - ...
    - RawContentBlockStopEvent
    - RawContentBlockStartEvent with ToolUseBlock specifying the function name
    - RawContentBlockDeltaEvent with a InputJSONDelta
    - RawContentBlockDeltaEvent with a InputJSONDelta
    - ...
    - RawContentBlockStopEvent
    - RawMessageDeltaEvent with a stop_reason='tool_use'
    - RawMessageStopEvent(type='message_stop')

    Each message could contain multiple blocks of the same type.
    """
    if result is None:
        raise TypeError("Expected a stream of messages")

    # State of the stream: the message being assembled, the content block
    # currently open inside it, and the partial-JSON accumulator for the
    # arguments of an in-progress tool call.
    current_message: MessageParam | None = None
    current_block: (
        TextBlockParam
        | ToolUseBlockParam
        | ThinkingBlockParam
        | RedactedThinkingBlockParam
        | None
    ) = None
    # Only assigned when a ToolUseBlock opens; InputJSONDelta events are
    # expected to arrive exclusively between that start and its stop event.
    current_tool_args: str

    async for response in result:
        LOGGER.debug("Received response: %s", response)

        if isinstance(response, RawMessageStartEvent):
            if response.message.role != "assistant":
                raise ValueError("Unexpected message role")
            current_message = MessageParam(role=response.message.role, content=[])
        elif isinstance(response, RawContentBlockStartEvent):
            if isinstance(response.content_block, ToolUseBlock):
                # Tool arguments stream in as partial JSON; start with an
                # empty input and fill it in at the block stop event.
                current_block = ToolUseBlockParam(
                    type="tool_use",
                    id=response.content_block.id,
                    name=response.content_block.name,
                    input="",
                )
                current_tool_args = ""
            elif isinstance(response.content_block, TextBlock):
                current_block = TextBlockParam(
                    type="text", text=response.content_block.text
                )
                # A text block opens a new assistant content entry in the
                # chat log; emit the role first, then any initial text.
                yield {"role": "assistant"}
                if response.content_block.text:
                    yield {"content": response.content_block.text}
            elif isinstance(response.content_block, ThinkingBlock):
                current_block = ThinkingBlockParam(
                    type="thinking",
                    thinking=response.content_block.thinking,
                    signature=response.content_block.signature,
                )
            elif isinstance(response.content_block, RedactedThinkingBlock):
                # Redacted thinking carries opaque encrypted data and never
                # receives delta events; it is kept only for the API replay.
                current_block = RedactedThinkingBlockParam(
                    type="redacted_thinking", data=response.content_block.data
                )
                LOGGER.debug(
                    "Some of Claude’s internal reasoning has been automatically "
                    "encrypted for safety reasons. This doesn’t affect the quality of "
                    "responses"
                )
        elif isinstance(response, RawContentBlockDeltaEvent):
            if current_block is None:
                raise ValueError("Unexpected delta without a block")
            if isinstance(response.delta, InputJSONDelta):
                # Accumulate; parsed once the block closes.
                current_tool_args += response.delta.partial_json
            elif isinstance(response.delta, TextDelta):
                text_block = cast(TextBlockParam, current_block)
                text_block["text"] += response.delta.text
                yield {"content": response.delta.text}
            elif isinstance(response.delta, ThinkingDelta):
                thinking_block = cast(ThinkingBlockParam, current_block)
                thinking_block["thinking"] += response.delta.thinking
            elif isinstance(response.delta, SignatureDelta):
                # Signature arrives as the final delta of a thinking block.
                thinking_block = cast(ThinkingBlockParam, current_block)
                thinking_block["signature"] += response.delta.signature
        elif isinstance(response, RawContentBlockStopEvent):
            if current_block is None:
                raise ValueError("Unexpected stop event without a current block")
            if current_block["type"] == "tool_use":
                # Parse the accumulated partial JSON and surface the call.
                tool_block = cast(ToolUseBlockParam, current_block)
                tool_args = json.loads(current_tool_args) if current_tool_args else {}
                tool_block["input"] = tool_args
                yield {
                    "tool_calls": [
                        llm.ToolInput(
                            id=tool_block["id"],
                            tool_name=tool_block["name"],
                            tool_args=tool_args,
                        )
                    ]
                }
            elif current_block["type"] == "thinking":
                thinking_block = cast(ThinkingBlockParam, current_block)
                LOGGER.debug("Thinking: %s", thinking_block["thinking"])

            if current_message is None:
                raise ValueError("Unexpected stop event without a current message")
            # Closed block becomes part of the message sent back to the API.
            current_message["content"].append(current_block)  # type: ignore[union-attr]
            current_block = None
        elif isinstance(response, RawMessageStopEvent):
            if current_message is not None:
                messages.append(current_message)
                current_message = None
|
||
|
||
|
||
class AnthropicConversationEntity(
    conversation.ConversationEntity, conversation.AbstractConversationAgent
):
    """Anthropic conversation agent."""

    # Entity is named after its device; no extra entity name suffix.
    _attr_has_entity_name = True
    _attr_name = None

    def __init__(self, entry: AnthropicConfigEntry) -> None:
        """Initialize the agent."""
        self.entry = entry
        self._attr_unique_id = entry.entry_id
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, entry.entry_id)},
            manufacturer="Anthropic",
            model="Claude",
            entry_type=dr.DeviceEntryType.SERVICE,
        )
        # Advertise device-control support only when an LLM API is configured.
        if self.entry.options.get(CONF_LLM_HASS_API):
            self._attr_supported_features = (
                conversation.ConversationEntityFeature.CONTROL
            )

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages."""
        return MATCH_ALL

    async def async_added_to_hass(self) -> None:
        """When entity is added to Home Assistant."""
        await super().async_added_to_hass()
        # Reload this entity when the config entry options change.
        self.entry.async_on_unload(
            self.entry.add_update_listener(self._async_entry_update_listener)
        )

    async def _async_handle_message(
        self,
        user_input: conversation.ConversationInput,
        chat_log: conversation.ChatLog,
    ) -> conversation.ConversationResult:
        """Call the API.

        Streams the model response into the chat log, looping to execute
        tool calls until the model stops requesting them (bounded by
        MAX_TOOL_ITERATIONS), then returns the final assistant message.
        """
        options = self.entry.options

        try:
            # Render the prompt and attach the LLM API/tools to the chat log.
            await chat_log.async_update_llm_data(
                DOMAIN,
                user_input,
                options.get(CONF_LLM_HASS_API),
                options.get(CONF_PROMPT),
            )
        except conversation.ConverseError as err:
            return err.as_conversation_result()

        tools: list[ToolParam] | None = None
        if chat_log.llm_api:
            tools = [
                _format_tool(tool, chat_log.llm_api.custom_serializer)
                for tool in chat_log.llm_api.tools
            ]

        # The chat log always starts with the system prompt, which the
        # Anthropic API takes separately from the message list.
        system = chat_log.content[0]
        if not isinstance(system, conversation.SystemContent):
            raise TypeError("First message must be a system message")
        messages = _convert_content(chat_log.content[1:])

        client = self.entry.runtime_data

        thinking_budget = options.get(CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET)
        model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)

        # To prevent infinite loops, we limit the number of iterations
        for _iteration in range(MAX_TOOL_ITERATIONS):
            model_args = {
                "model": model,
                "messages": messages,
                "tools": tools or NOT_GIVEN,
                "max_tokens": options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
                "system": system.content,
                "stream": True,
            }
            # Extended thinking is only enabled for supported models and
            # only when the configured budget meets the API minimum.
            if model in THINKING_MODELS and thinking_budget >= MIN_THINKING_BUDGET:
                model_args["thinking"] = ThinkingConfigEnabledParam(
                    type="enabled", budget_tokens=thinking_budget
                )
            else:
                model_args["thinking"] = ThinkingConfigDisabledParam(type="disabled")
                model_args["temperature"] = options.get(
                    CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
                )

            try:
                stream = await client.messages.create(**model_args)
            except anthropic.AnthropicError as err:
                raise HomeAssistantError(
                    f"Sorry, I had a problem talking to Anthropic: {err}"
                ) from err

            # _transform_stream appends the assistant message to `messages`
            # itself, so only non-assistant content (tool results) produced
            # by the chat log is converted and appended here.
            messages.extend(
                _convert_content(
                    [
                        content
                        async for content in chat_log.async_add_delta_content_stream(
                            user_input.agent_id, _transform_stream(stream, messages)
                        )
                        if not isinstance(content, conversation.AssistantContent)
                    ]
                )
            )

            # Done once the model no longer has pending tool calls to answer.
            if not chat_log.unresponded_tool_results:
                break

        response_content = chat_log.content[-1]
        if not isinstance(response_content, conversation.AssistantContent):
            raise TypeError("Last message must be an assistant message")
        intent_response = intent.IntentResponse(language=user_input.language)
        intent_response.async_set_speech(response_content.content or "")
        return conversation.ConversationResult(
            response=intent_response,
            conversation_id=chat_log.conversation_id,
            continue_conversation=chat_log.continue_conversation,
        )

    async def _async_entry_update_listener(
        self, hass: HomeAssistant, entry: ConfigEntry
    ) -> None:
        """Handle options update."""
        # Reload as we update device info + entity name + supported features
        await hass.config_entries.async_reload(entry.entry_id)
|