mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 17:27:52 +00:00
Push more of the LLM conversation agent loop into ChatSession (#136602)
* Push more of the LLM conversation agent loop into ChatSession * Revert unnecessary changes * Revert changes to agent id filtering
This commit is contained in:
parent
dfbb48552c
commit
69938545df
@ -9,7 +9,8 @@ from typing import Any
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import Context, HomeAssistant, async_get_hass, callback
|
||||
from homeassistant.helpers import config_validation as cv, singleton
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv, intent, singleton
|
||||
|
||||
from .const import (
|
||||
DATA_COMPONENT,
|
||||
@ -109,7 +110,19 @@ async def async_converse(
|
||||
dataclasses.asdict(conversation_input),
|
||||
)
|
||||
)
|
||||
result = await method(conversation_input)
|
||||
try:
|
||||
result = await method(conversation_input)
|
||||
except HomeAssistantError as err:
|
||||
intent_response = intent.IntentResponse(language=language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
str(err),
|
||||
)
|
||||
result = ConversationResult(
|
||||
response=intent_response,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
trace.set_result(**result.as_dict())
|
||||
return result
|
||||
|
||||
|
@ -9,6 +9,8 @@ from datetime import datetime, timedelta
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
@ -23,7 +25,9 @@ from homeassistant.helpers import intent, llm, template
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
from . import trace
|
||||
from .const import DOMAIN
|
||||
from .models import ConversationInput, ConversationResult
|
||||
|
||||
@ -120,7 +124,7 @@ async def async_get_chat_session(
|
||||
if history:
|
||||
history = replace(history, messages=history.messages.copy())
|
||||
else:
|
||||
history = ChatSession(hass, conversation_id)
|
||||
history = ChatSession(hass, conversation_id, user_input.agent_id)
|
||||
|
||||
message: ChatMessage = ChatMessage(
|
||||
role="user",
|
||||
@ -190,6 +194,7 @@ class ChatSession[_NativeT]:
|
||||
|
||||
hass: HomeAssistant
|
||||
conversation_id: str
|
||||
agent_id: str | None
|
||||
user_name: str | None = None
|
||||
messages: list[ChatMessage[_NativeT]] = field(
|
||||
default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")]
|
||||
@ -209,7 +214,9 @@ class ChatSession[_NativeT]:
|
||||
self.messages.append(message)
|
||||
|
||||
@callback
|
||||
def async_get_messages(self, agent_id: str | None) -> list[ChatMessage[_NativeT]]:
|
||||
def async_get_messages(
|
||||
self, agent_id: str | None = None
|
||||
) -> list[ChatMessage[_NativeT]]:
|
||||
"""Get messages for a specific agent ID.
|
||||
|
||||
This will filter out any native message tied to other agent IDs.
|
||||
@ -326,3 +333,29 @@ class ChatSession[_NativeT]:
|
||||
agent_id=user_input.agent_id,
|
||||
content=prompt,
|
||||
)
|
||||
|
||||
LOGGER.debug("Prompt: %s", self.messages)
|
||||
LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)
|
||||
|
||||
trace.async_conversation_trace_append(
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
{
|
||||
"messages": self.messages,
|
||||
"tools": self.llm_api.tools if self.llm_api else None,
|
||||
},
|
||||
)
|
||||
|
||||
async def async_call_tool(self, tool_input: llm.ToolInput) -> JsonObjectType:
|
||||
"""Invoke LLM tool for the configured LLM API."""
|
||||
if not self.llm_api:
|
||||
raise ValueError("No LLM API configured")
|
||||
LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
|
||||
|
||||
try:
|
||||
tool_response = await self.llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as e:
|
||||
tool_response = {"error": type(e).__name__}
|
||||
if str(e):
|
||||
tool_response["error_text"] = str(e)
|
||||
LOGGER.debug("Tool response: %s", tool_response)
|
||||
return tool_response
|
||||
|
@ -16,11 +16,9 @@ from openai.types.chat import (
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
||||
from openai.types.shared_params import FunctionDefinition
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import assist_pipeline, conversation
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
@ -94,6 +92,19 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar
|
||||
return param
|
||||
|
||||
|
||||
def _chat_message_convert(
|
||||
message: conversation.ChatMessage[ChatCompletionMessageParam],
|
||||
agent_id: str | None,
|
||||
) -> ChatCompletionMessageParam:
|
||||
"""Convert any native chat message for this agent to the native format."""
|
||||
if message.native is not None and message.agent_id == agent_id:
|
||||
return message.native
|
||||
return cast(
|
||||
ChatCompletionMessageParam,
|
||||
{"role": message.role, "content": message.content},
|
||||
)
|
||||
|
||||
|
||||
class OpenAIConversationEntity(
|
||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||
):
|
||||
@ -173,27 +184,10 @@ class OpenAIConversationEntity(
|
||||
for tool in session.llm_api.tools
|
||||
]
|
||||
|
||||
messages: list[ChatCompletionMessageParam] = []
|
||||
for message in session.async_get_messages(user_input.agent_id):
|
||||
if message.native is not None and message.agent_id == user_input.agent_id:
|
||||
messages.append(message.native)
|
||||
else:
|
||||
messages.append(
|
||||
cast(
|
||||
ChatCompletionMessageParam,
|
||||
{"role": message.role, "content": message.content},
|
||||
)
|
||||
)
|
||||
|
||||
LOGGER.debug("Prompt: %s", messages)
|
||||
LOGGER.debug("Tools: %s", tools)
|
||||
trace.async_conversation_trace_append(
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
{
|
||||
"messages": session.messages,
|
||||
"tools": session.llm_api.tools if session.llm_api else None,
|
||||
},
|
||||
)
|
||||
messages = [
|
||||
_chat_message_convert(message, user_input.agent_id)
|
||||
for message in session.async_get_messages()
|
||||
]
|
||||
|
||||
client = self.entry.runtime_data
|
||||
|
||||
@ -211,14 +205,7 @@ class OpenAIConversationEntity(
|
||||
)
|
||||
except openai.OpenAIError as err:
|
||||
LOGGER.error("Error talking to OpenAI: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
"Sorry, I had a problem talking to OpenAI",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=session.conversation_id
|
||||
)
|
||||
raise HomeAssistantError("Error talking to OpenAI") from err
|
||||
|
||||
LOGGER.debug("Response %s", result)
|
||||
response = result.choices[0].message
|
||||
@ -241,18 +228,7 @@ class OpenAIConversationEntity(
|
||||
tool_name=tool_call.function.name,
|
||||
tool_args=json.loads(tool_call.function.arguments),
|
||||
)
|
||||
LOGGER.debug(
|
||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||
)
|
||||
|
||||
try:
|
||||
tool_response = await session.llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as e:
|
||||
tool_response = {"error": type(e).__name__}
|
||||
if str(e):
|
||||
tool_response["error_text"] = str(e)
|
||||
|
||||
LOGGER.debug("Tool response: %s", tool_response)
|
||||
tool_response = await session.async_call_tool(tool_input)
|
||||
messages.append(
|
||||
ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
|
@ -2,13 +2,15 @@
|
||||
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.conversation import ConversationInput, session
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
@ -182,7 +184,7 @@ async def test_message_filtering(
|
||||
)
|
||||
assert messages[1] == session.ChatMessage(
|
||||
role="user",
|
||||
agent_id=mock_conversation_input.agent_id,
|
||||
agent_id="mock-agent-id",
|
||||
content=mock_conversation_input.text,
|
||||
)
|
||||
# Cannot add a second user message in a row
|
||||
@ -203,7 +205,7 @@ async def test_message_filtering(
|
||||
native="assistant-reply-native",
|
||||
)
|
||||
)
|
||||
# Different agent, will be filtered out.
|
||||
# Different agent, native messages will be filtered out.
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
role="native", agent_id="another-mock-agent-id", content="", native=1
|
||||
@ -214,11 +216,20 @@ async def test_message_filtering(
|
||||
role="native", agent_id="mock-agent-id", content="", native=1
|
||||
)
|
||||
)
|
||||
# A non-native message from another agent is not filtered out.
|
||||
chat_session.async_add_message(
|
||||
session.ChatMessage(
|
||||
role="assistant",
|
||||
agent_id="another-mock-agent-id",
|
||||
content="Hi!",
|
||||
native=1,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(chat_session.messages) == 5
|
||||
assert len(chat_session.messages) == 6
|
||||
|
||||
messages = chat_session.async_get_messages(agent_id="mock-agent-id")
|
||||
assert len(messages) == 4
|
||||
assert len(messages) == 5
|
||||
|
||||
assert messages[2] == session.ChatMessage(
|
||||
role="assistant",
|
||||
@ -229,6 +240,9 @@ async def test_message_filtering(
|
||||
assert messages[3] == session.ChatMessage(
|
||||
role="native", agent_id="mock-agent-id", content="", native=1
|
||||
)
|
||||
assert messages[4] == session.ChatMessage(
|
||||
role="assistant", agent_id="another-mock-agent-id", content="Hi!", native=1
|
||||
)
|
||||
|
||||
|
||||
async def test_llm_api(
|
||||
@ -413,3 +427,81 @@ async def test_extra_systen_prompt(
|
||||
|
||||
assert chat_session.extra_system_prompt == extra_system_prompt2
|
||||
assert chat_session.messages[0].content.endswith(extra_system_prompt2)
|
||||
|
||||
|
||||
async def test_tool_call(
|
||||
hass: HomeAssistant,
|
||||
mock_conversation_input: ConversationInput,
|
||||
) -> None:
|
||||
"""Test using the session tool calling API."""
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.name = "test_tool"
|
||||
mock_tool.description = "Test function"
|
||||
mock_tool.parameters = vol.Schema(
|
||||
{vol.Optional("param1", description="Test parameters"): str}
|
||||
)
|
||||
mock_tool.async_call.return_value = "Test response"
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
|
||||
return_value=[],
|
||||
) as mock_get_tools:
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
await chat_session.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
result = await chat_session.async_call_tool(
|
||||
llm.ToolInput(
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "Test Param"},
|
||||
)
|
||||
)
|
||||
|
||||
assert result == "Test response"
|
||||
|
||||
|
||||
async def test_tool_call_exception(
|
||||
hass: HomeAssistant,
|
||||
mock_conversation_input: ConversationInput,
|
||||
) -> None:
|
||||
"""Test using the session tool calling API."""
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.name = "test_tool"
|
||||
mock_tool.description = "Test function"
|
||||
mock_tool.parameters = vol.Schema(
|
||||
{vol.Optional("param1", description="Test parameters"): str}
|
||||
)
|
||||
mock_tool.async_call.side_effect = HomeAssistantError("Test error")
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
|
||||
return_value=[],
|
||||
) as mock_get_tools:
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
async with session.async_get_chat_session(
|
||||
hass, mock_conversation_input
|
||||
) as chat_session:
|
||||
await chat_session.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api="assist",
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
result = await chat_session.async_call_tool(
|
||||
llm.ToolInput(
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "Test Param"},
|
||||
)
|
||||
)
|
||||
|
||||
assert result == {"error": "HomeAssistantError", "error_text": "Test error"}
|
||||
|
@ -61,18 +61,18 @@ async def test_converation_trace(
|
||||
}
|
||||
|
||||
|
||||
async def test_converation_trace_error(
|
||||
async def test_converation_trace_uncaught_error(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
sl_setup: None,
|
||||
) -> None:
|
||||
"""Test tracing a conversation."""
|
||||
"""Test tracing a conversation that raises an uncaught error."""
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.conversation.default_agent.DefaultAgent.async_process",
|
||||
side_effect=HomeAssistantError("Failed to talk to agent"),
|
||||
side_effect=ValueError("Unexpected error"),
|
||||
),
|
||||
pytest.raises(HomeAssistantError),
|
||||
pytest.raises(ValueError),
|
||||
):
|
||||
await conversation.async_converse(
|
||||
hass, "add apples to my shopping list", None, Context()
|
||||
@ -87,4 +87,35 @@ async def test_converation_trace_error(
|
||||
assert (
|
||||
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
|
||||
)
|
||||
assert last_trace.get("error") == "Failed to talk to agent"
|
||||
assert last_trace.get("error") == "Unexpected error"
|
||||
assert not last_trace.get("result")
|
||||
|
||||
|
||||
async def test_converation_trace_homeassistant_error(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
sl_setup: None,
|
||||
) -> None:
|
||||
"""Test tracing a conversation with a HomeAssistant error."""
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.conversation.default_agent.DefaultAgent.async_process",
|
||||
side_effect=HomeAssistantError("Failed to talk to agent"),
|
||||
),
|
||||
):
|
||||
await conversation.async_converse(
|
||||
hass, "add apples to my shopping list", None, Context()
|
||||
)
|
||||
|
||||
traces = trace.async_get_traces()
|
||||
assert traces
|
||||
last_trace = traces[-1].as_dict()
|
||||
assert last_trace.get("events")
|
||||
assert len(last_trace.get("events")) == 1
|
||||
trace_event = last_trace["events"][0]
|
||||
assert (
|
||||
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
|
||||
)
|
||||
result = last_trace.get("result")
|
||||
assert result
|
||||
assert result["response"]["speech"]["plain"]["speech"] == "Failed to talk to agent"
|
||||
|
Loading…
x
Reference in New Issue
Block a user