Update anthropic to use the new chatlog API (#138178)
* Update anthropic to use the new chatlog API
* Remove conversation id logging
* Add back whitespace
* Reduce unnecessary diffs
* Revert diffs to conversation component
* Replace types with union type
This commit is contained in:
parent 29c6a2ec13
commit ae38f89728
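The change replaces the integration's hand-rolled conversation tracking (a `self.history` dict keyed by ULID) with the conversation component's shared chat log. A minimal sketch of the pattern the commit adopts, mirroring the new `async_process` body in the diff below; it assumes it runs inside a conversation entity, and the file paths shown in the diff are inferred from the patch targets and the standard Home Assistant layout:

# Minimal sketch of the chat-log pattern this commit adopts; assumes a
# conversation entity context (self.hass and self._async_handle_message).
from homeassistant.components import conversation
from homeassistant.helpers import chat_session


async def async_process(
    self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
    """Process a sentence."""
    with (
        # The chat session resolves or creates the conversation id for us.
        chat_session.async_get_chat_session(
            self.hass, user_input.conversation_id
        ) as session,
        # The chat log owns the message history the entity used to keep
        # in self.history.
        conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
    ):
        return await self._async_handle_message(user_input, chat_log)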
homeassistant/components/anthropic/conversation.py

@@ -16,18 +16,15 @@ from anthropic.types import (
     ToolUseBlock,
     ToolUseBlockParam,
 )
-import voluptuous as vol
 from voluptuous_openapi import convert
 
 from homeassistant.components import conversation
-from homeassistant.components.conversation import trace
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import HomeAssistantError, TemplateError
-from homeassistant.helpers import device_registry as dr, intent, llm, template
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import chat_session, device_registry as dr, intent, llm
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
-from homeassistant.util import ulid as ulid_util
 
 from . import AnthropicConfigEntry
 from .const import (
@@ -89,6 +86,44 @@ def _message_convert(
     return MessageParam(role=message.role, content=param_content)
 
 
+def _convert_content(chat_content: conversation.Content) -> MessageParam:
+    """Create tool response content."""
+    if isinstance(chat_content, conversation.ToolResultContent):
+        return MessageParam(
+            role="user",
+            content=[
+                ToolResultBlockParam(
+                    type="tool_result",
+                    tool_use_id=chat_content.tool_call_id,
+                    content=json.dumps(chat_content.tool_result),
+                )
+            ],
+        )
+    if isinstance(chat_content, conversation.AssistantContent):
+        return MessageParam(
+            role="assistant",
+            content=[
+                TextBlockParam(type="text", text=chat_content.content or ""),
+                *[
+                    ToolUseBlockParam(
+                        type="tool_use",
+                        id=tool_call.id,
+                        name=tool_call.tool_name,
+                        input=json.dumps(tool_call.tool_args),
+                    )
+                    for tool_call in chat_content.tool_calls or ()
+                ],
+            ],
+        )
+    if isinstance(chat_content, conversation.UserContent):
+        return MessageParam(
+            role="user",
+            content=chat_content.content,
+        )
+    # Note: We don't pass SystemContent here as its passed to the API as the prompt
+    raise ValueError(f"Unexpected content type: {type(chat_content)}")
+
+
 class AnthropicConversationEntity(
     conversation.ConversationEntity, conversation.AbstractConversationAgent
 ):
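For illustration, here is roughly what the new `_convert_content` helper emits for a tool result pulled from the chat log. The helper only reads `tool_call_id` and `tool_result`; the other constructor fields shown are assumptions, and all values are invented:

# Hypothetical input/output of _convert_content (values invented).
tool_result = conversation.ToolResultContent(
    agent_id="conversation.claude",  # assumed field
    tool_call_id="toolu_abc123",
    tool_name="HassTurnOn",          # assumed field
    tool_result={"success": True},
)
message = _convert_content(tool_result)
# message == {
#     "role": "user",
#     "content": [{
#         "type": "tool_result",
#         "tool_use_id": "toolu_abc123",
#         "content": '{"success": true}',
#     }],
# }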
@@ -100,7 +135,6 @@ class AnthropicConversationEntity(
     def __init__(self, entry: AnthropicConfigEntry) -> None:
         """Initialize the agent."""
         self.entry = entry
-        self.history: dict[str, list[MessageParam]] = {}
         self._attr_unique_id = entry.entry_id
         self._attr_device_info = dr.DeviceInfo(
             identifiers={(DOMAIN, entry.entry_id)},
@@ -129,110 +163,43 @@ class AnthropicConversationEntity(
         self, user_input: conversation.ConversationInput
     ) -> conversation.ConversationResult:
         """Process a sentence."""
-        options = self.entry.options
-        intent_response = intent.IntentResponse(language=user_input.language)
-        llm_api: llm.APIInstance | None = None
-        tools: list[ToolParam] | None = None
-        user_name: str | None = None
-        llm_context = llm.LLMContext(
-            platform=DOMAIN,
-            context=user_input.context,
-            user_prompt=user_input.text,
-            language=user_input.language,
-            assistant=conversation.DOMAIN,
-            device_id=user_input.device_id,
-        )
-
-        if options.get(CONF_LLM_HASS_API):
-            try:
-                llm_api = await llm.async_get_api(
-                    self.hass,
-                    options[CONF_LLM_HASS_API],
-                    llm_context,
-                )
-            except HomeAssistantError as err:
-                LOGGER.error("Error getting LLM API: %s", err)
-                intent_response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Error preparing LLM API: {err}",
-                )
-                return conversation.ConversationResult(
-                    response=intent_response, conversation_id=user_input.conversation_id
-                )
-            tools = [
-                _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
-            ]
-
-        if user_input.conversation_id is None:
-            conversation_id = ulid_util.ulid_now()
-            messages = []
-
-        elif user_input.conversation_id in self.history:
-            conversation_id = user_input.conversation_id
-            messages = self.history[conversation_id]
-
-        else:
-            # Conversation IDs are ULIDs. We generate a new one if not provided.
-            # If an old OLID is passed in, we will generate a new one to indicate
-            # a new conversation was started. If the user picks their own, they
-            # want to track a conversation and we respect it.
-            try:
-                ulid_util.ulid_to_bytes(user_input.conversation_id)
-                conversation_id = ulid_util.ulid_now()
-            except ValueError:
-                conversation_id = user_input.conversation_id
-
-            messages = []
-
-        if (
-            user_input.context
-            and user_input.context.user_id
-            and (
-                user := await self.hass.auth.async_get_user(user_input.context.user_id)
-            )
-        ):
-            user_name = user.name
-
-        try:
-            prompt_parts = [
-                template.Template(
-                    llm.BASE_PROMPT
-                    + options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
-                    self.hass,
-                ).async_render(
-                    {
-                        "ha_name": self.hass.config.location_name,
-                        "user_name": user_name,
-                        "llm_context": llm_context,
-                    },
-                    parse_result=False,
-                )
-            ]
-
-        except TemplateError as err:
-            LOGGER.error("Error rendering prompt: %s", err)
-            intent_response.async_set_error(
-                intent.IntentResponseErrorCode.UNKNOWN,
-                f"Sorry, I had a problem with my template: {err}",
-            )
-            return conversation.ConversationResult(
-                response=intent_response, conversation_id=conversation_id
-            )
-
-        if llm_api:
-            prompt_parts.append(llm_api.api_prompt)
-
-        prompt = "\n".join(prompt_parts)
-
-        # Create a copy of the variable because we attach it to the trace
-        messages = [*messages, MessageParam(role="user", content=user_input.text)]
-
-        LOGGER.debug("Prompt: %s", messages)
-        LOGGER.debug("Tools: %s", tools)
-        trace.async_conversation_trace_append(
-            trace.ConversationTraceEventType.AGENT_DETAIL,
-            {"system": prompt, "messages": messages},
-        )
+        with (
+            chat_session.async_get_chat_session(
+                self.hass, user_input.conversation_id
+            ) as session,
+            conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
+        ):
+            return await self._async_handle_message(user_input, chat_log)
+
+    async def _async_handle_message(
+        self,
+        user_input: conversation.ConversationInput,
+        chat_log: conversation.ChatLog,
+    ) -> conversation.ConversationResult:
+        """Call the API."""
+        options = self.entry.options
+
+        try:
+            await chat_log.async_update_llm_data(
+                DOMAIN,
+                user_input,
+                options.get(CONF_LLM_HASS_API),
+                options.get(CONF_PROMPT),
+            )
+        except conversation.ConverseError as err:
+            return err.as_conversation_result()
+
+        tools: list[ToolParam] | None = None
+        if chat_log.llm_api:
+            tools = [
+                _format_tool(tool, chat_log.llm_api.custom_serializer)
+                for tool in chat_log.llm_api.tools
+            ]
+
+        system = chat_log.content[0]
+        if not isinstance(system, conversation.SystemContent):
+            raise TypeError("First message must be a system message")
+        messages = [_convert_content(content) for content in chat_log.content[1:]]
 
         client = self.entry.runtime_data
 
@@ -244,69 +211,62 @@ class AnthropicConversationEntity(
                     messages=messages,
                     tools=tools or NOT_GIVEN,
                     max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
-                    system=prompt,
+                    system=system.content,
                     temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
                 )
             except anthropic.AnthropicError as err:
-                intent_response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Sorry, I had a problem talking to Anthropic: {err}",
-                )
-                return conversation.ConversationResult(
-                    response=intent_response, conversation_id=conversation_id
-                )
+                raise HomeAssistantError(
+                    f"Sorry, I had a problem talking to Anthropic: {err}"
+                ) from err
 
             LOGGER.debug("Response %s", response)
 
             messages.append(_message_convert(response))
 
-            if response.stop_reason != "tool_use" or not llm_api:
-                break
-
-            tool_results: list[ToolResultBlockParam] = []
-            for tool_call in response.content:
-                if isinstance(tool_call, TextBlock):
-                    LOGGER.info(tool_call.text)
-                if not isinstance(tool_call, ToolUseBlock):
-                    continue
-
-                tool_input = llm.ToolInput(
-                    id=tool_call.id,
-                    tool_name=tool_call.name,
-                    tool_args=cast(dict[str, Any], tool_call.input),
-                )
-                LOGGER.debug(
-                    "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
-                )
-
-                try:
-                    tool_response = await llm_api.async_call_tool(tool_input)
-                except (HomeAssistantError, vol.Invalid) as e:
-                    tool_response = {"error": type(e).__name__}
-                    if str(e):
-                        tool_response["error_text"] = str(e)
-
-                LOGGER.debug("Tool response: %s", tool_response)
-                tool_results.append(
-                    ToolResultBlockParam(
-                        type="tool_result",
-                        tool_use_id=tool_call.id,
-                        content=json.dumps(tool_response),
-                    )
-                )
-
-            messages.append(MessageParam(role="user", content=tool_results))
-
-            self.history[conversation_id] = messages
-
-            for content in response.content:
-                if isinstance(content, TextBlock):
-                    intent_response.async_set_speech(content.text)
-                    break
-
+            text = "".join(
+                [
+                    content.text
+                    for content in response.content
+                    if isinstance(content, TextBlock)
+                ]
+            )
+            tool_inputs = [
+                llm.ToolInput(
+                    id=tool_call.id,
+                    tool_name=tool_call.name,
+                    tool_args=cast(dict[str, Any], tool_call.input),
+                )
+                for tool_call in response.content
+                if isinstance(tool_call, ToolUseBlock)
+            ]
+
+            tool_results = [
+                ToolResultBlockParam(
+                    type="tool_result",
+                    tool_use_id=tool_response.tool_call_id,
+                    content=json.dumps(tool_response.tool_result),
+                )
+                async for tool_response in chat_log.async_add_assistant_content(
+                    conversation.AssistantContent(
+                        agent_id=user_input.agent_id,
+                        content=text,
+                        tool_calls=tool_inputs or None,
+                    )
+                )
+            ]
+            if tool_results:
+                messages.append(MessageParam(role="user", content=tool_results))
+
+            if not tool_inputs:
+                break
+
+        response_content = chat_log.content[-1]
+        if not isinstance(response_content, conversation.AssistantContent):
+            raise TypeError("Last message must be an assistant message")
+        intent_response = intent.IntentResponse(language=user_input.language)
+        intent_response.async_set_speech(response_content.content or "")
         return conversation.ConversationResult(
-            response=intent_response, conversation_id=conversation_id
+            response=intent_response, conversation_id=chat_log.conversation_id
         )
 
     async def _async_entry_update_listener(
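The heart of the rewrite above: instead of invoking `llm_api.async_call_tool` itself, the entity now hands the assistant reply (its text plus any `ToolInput`s) to `chat_log.async_add_assistant_content`, which records the message, executes the requested tools, and yields one tool response per call. Condensed from the hunk, with the surrounding model-call loop elided:

# Core of the new tool round-trip (lifted from the hunk above; runs inside
# the model-call loop). The chat log executes the tools and yields the
# results, which are converted back into Anthropic tool_result blocks.
tool_results = [
    ToolResultBlockParam(
        type="tool_result",
        tool_use_id=tool_response.tool_call_id,
        content=json.dumps(tool_response.tool_result),
    )
    async for tool_response in chat_log.async_add_assistant_content(
        conversation.AssistantContent(
            agent_id=user_input.agent_id,
            content=text,
            tool_calls=tool_inputs or None,  # None: model requested no tools
        )
    )
]
if not tool_inputs:
    break  # no tool calls were requested, so the loop is done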
tests/components/anthropic/snapshots/test_conversation.ambr

@@ -1,7 +1,7 @@
 # serializer version: 1
 # name: test_unknown_hass_api
   dict({
-    'conversation_id': None,
+    'conversation_id': '1234',
     'response': IntentResponse(
       card=dict({
       }),
@@ -20,7 +20,7 @@
     speech=dict({
       'plain': dict({
         'extra_data': None,
-        'speech': 'Error preparing LLM API: API non-existing not found',
+        'speech': 'Error preparing LLM API',
       }),
     }),
     speech_slots=dict({
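The snapshot updates follow from the new error path: failures while preparing the LLM API now surface through `chat_log.async_update_llm_data` as a `conversation.ConverseError`, whose result carries the generic "Error preparing LLM API" speech and echoes the caller-supplied conversation id "1234" instead of None. The relevant handling from the conversation.py hunk above (how `ConverseError` formats the spoken error is the shared helper's concern, which is an assumption here):

# Error path behind the updated snapshot (from the diff above); runs inside
# _async_handle_message, so chat_log, options, and user_input are in scope.
try:
    await chat_log.async_update_llm_data(
        DOMAIN,
        user_input,
        options.get(CONF_LLM_HASS_API),
        options.get(CONF_PROMPT),
    )
except conversation.ConverseError as err:
    return err.as_conversation_result()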
tests/components/anthropic/test_conversation.py

@@ -10,7 +10,6 @@ from syrupy.assertion import SnapshotAssertion
 import voluptuous as vol
 
 from homeassistant.components import conversation
-from homeassistant.components.conversation import trace
 from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import Context, HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
@@ -250,42 +249,6 @@ async def test_function_call(
         ),
     )
 
-    # Test Conversation tracing
-    traces = trace.async_get_traces()
-    assert traces
-    last_trace = traces[-1].as_dict()
-    trace_events = last_trace.get("events", [])
-    assert [event["event_type"] for event in trace_events] == [
-        trace.ConversationTraceEventType.ASYNC_PROCESS,
-        trace.ConversationTraceEventType.AGENT_DETAIL,
-        trace.ConversationTraceEventType.TOOL_CALL,
-    ]
-    # AGENT_DETAIL event contains the raw prompt passed to the model
-    detail_event = trace_events[1]
-    assert "Answer in plain text" in detail_event["data"]["system"]
-    assert "Today's date is 2024-06-03." in trace_events[1]["data"]["system"]
-
-    # Call it again, make sure we have updated prompt
-    with (
-        patch(
-            "anthropic.resources.messages.AsyncMessages.create",
-            new_callable=AsyncMock,
-            side_effect=completion_result,
-        ) as mock_create,
-        freeze_time("2024-06-04 23:00:00"),
-    ):
-        result = await conversation.async_converse(
-            hass,
-            "Please call the test function",
-            None,
-            context,
-            agent_id=agent_id,
-        )
-
-    assert "Today's date is 2024-06-04." in mock_create.mock_calls[1][2]["system"]
-    # Test old assert message not updated
-    assert "Today's date is 2024-06-03." in trace_events[1]["data"]["system"]
-
 
 @patch("homeassistant.components.anthropic.conversation.llm.AssistAPI._async_get_tools")
 async def test_function_exception(
@@ -448,7 +411,7 @@ async def test_unknown_hass_api(
     )
 
     result = await conversation.async_converse(
-        hass, "hello", None, Context(), agent_id="conversation.claude"
+        hass, "hello", "1234", Context(), agent_id="conversation.claude"
     )
 
     assert result == snapshot