Push more of the LLM conversation agent loop into ChatSession (#136602)

* Push more of the LLM conversation agent loop into ChatSession

* Revert unnecessary changes

* Revert changes to agent id filtering
This commit is contained in:
Allen Porter 2025-01-26 19:16:19 -08:00 committed by GitHub
parent dfbb48552c
commit 69938545df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 202 additions and 57 deletions

View File

@ -9,7 +9,8 @@ from typing import Any
import voluptuous as vol
from homeassistant.core import Context, HomeAssistant, async_get_hass, callback
from homeassistant.helpers import config_validation as cv, singleton
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent, singleton
from .const import (
DATA_COMPONENT,
@ -109,7 +110,19 @@ async def async_converse(
dataclasses.asdict(conversation_input),
)
)
result = await method(conversation_input)
try:
result = await method(conversation_input)
except HomeAssistantError as err:
intent_response = intent.IntentResponse(language=language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
str(err),
)
result = ConversationResult(
response=intent_response,
conversation_id=conversation_id,
)
trace.set_result(**result.as_dict())
return result

View File

@ -9,6 +9,8 @@ from datetime import datetime, timedelta
import logging
from typing import Literal
import voluptuous as vol
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import (
CALLBACK_TYPE,
@ -23,7 +25,9 @@ from homeassistant.helpers import intent, llm, template
from homeassistant.helpers.event import async_call_later
from homeassistant.util import dt as dt_util, ulid as ulid_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType
from . import trace
from .const import DOMAIN
from .models import ConversationInput, ConversationResult
@ -120,7 +124,7 @@ async def async_get_chat_session(
if history:
history = replace(history, messages=history.messages.copy())
else:
history = ChatSession(hass, conversation_id)
history = ChatSession(hass, conversation_id, user_input.agent_id)
message: ChatMessage = ChatMessage(
role="user",
@ -190,6 +194,7 @@ class ChatSession[_NativeT]:
hass: HomeAssistant
conversation_id: str
agent_id: str | None
user_name: str | None = None
messages: list[ChatMessage[_NativeT]] = field(
default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")]
@ -209,7 +214,9 @@ class ChatSession[_NativeT]:
self.messages.append(message)
@callback
def async_get_messages(self, agent_id: str | None) -> list[ChatMessage[_NativeT]]:
def async_get_messages(
self, agent_id: str | None = None
) -> list[ChatMessage[_NativeT]]:
"""Get messages for a specific agent ID.
This will filter out any native message tied to other agent IDs.
@ -326,3 +333,29 @@ class ChatSession[_NativeT]:
agent_id=user_input.agent_id,
content=prompt,
)
LOGGER.debug("Prompt: %s", self.messages)
LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL,
{
"messages": self.messages,
"tools": self.llm_api.tools if self.llm_api else None,
},
)
async def async_call_tool(self, tool_input: llm.ToolInput) -> JsonObjectType:
    """Invoke LLM tool for the configured LLM API.

    Returns the tool's JSON response; a HomeAssistantError or schema
    validation failure is converted into an error payload instead of
    propagating, so the LLM can react to the failure.
    """
    # Calling a tool without a configured LLM API is a programming error.
    if not self.llm_api:
        raise ValueError("No LLM API configured")

    LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
    try:
        response = await self.llm_api.async_call_tool(tool_input)
    except (HomeAssistantError, vol.Invalid) as err:
        # Report the failure back to the caller as structured data.
        response = {"error": type(err).__name__}
        error_text = str(err)
        if error_text:
            response["error_text"] = error_text
    LOGGER.debug("Tool response: %s", response)
    return response

View File

@ -16,11 +16,9 @@ from openai.types.chat import (
)
from openai.types.chat.chat_completion_message_tool_call_param import Function
from openai.types.shared_params import FunctionDefinition
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
@ -94,6 +92,19 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar
return param
def _chat_message_convert(
    message: conversation.ChatMessage[ChatCompletionMessageParam],
    agent_id: str | None,
) -> ChatCompletionMessageParam:
    """Convert any native chat message for this agent to the native format."""
    # Messages produced by this agent carry a provider-native payload; reuse it.
    native = message.native
    if native is not None and message.agent_id == agent_id:
        return native
    # Everything else is reduced to a plain role/content message.
    plain = {"role": message.role, "content": message.content}
    return cast(ChatCompletionMessageParam, plain)
class OpenAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@ -173,27 +184,10 @@ class OpenAIConversationEntity(
for tool in session.llm_api.tools
]
messages: list[ChatCompletionMessageParam] = []
for message in session.async_get_messages(user_input.agent_id):
if message.native is not None and message.agent_id == user_input.agent_id:
messages.append(message.native)
else:
messages.append(
cast(
ChatCompletionMessageParam,
{"role": message.role, "content": message.content},
)
)
LOGGER.debug("Prompt: %s", messages)
LOGGER.debug("Tools: %s", tools)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL,
{
"messages": session.messages,
"tools": session.llm_api.tools if session.llm_api else None,
},
)
messages = [
_chat_message_convert(message, user_input.agent_id)
for message in session.async_get_messages()
]
client = self.entry.runtime_data
@ -211,14 +205,7 @@ class OpenAIConversationEntity(
)
except openai.OpenAIError as err:
LOGGER.error("Error talking to OpenAI: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Sorry, I had a problem talking to OpenAI",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=session.conversation_id
)
raise HomeAssistantError("Error talking to OpenAI") from err
LOGGER.debug("Response %s", result)
response = result.choices[0].message
@ -241,18 +228,7 @@ class OpenAIConversationEntity(
tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments),
)
LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)
try:
tool_response = await session.llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__}
if str(e):
tool_response["error_text"] = str(e)
LOGGER.debug("Tool response: %s", tool_response)
tool_response = await session.async_call_tool(tool_input)
messages.append(
ChatCompletionToolMessageParam(
role="tool",

View File

@ -2,13 +2,15 @@
from collections.abc import Generator
from datetime import timedelta
from unittest.mock import Mock, patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components.conversation import ConversationInput, session
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.util import dt as dt_util
@ -182,7 +184,7 @@ async def test_message_filtering(
)
assert messages[1] == session.ChatMessage(
role="user",
agent_id=mock_conversation_input.agent_id,
agent_id="mock-agent-id",
content=mock_conversation_input.text,
)
# Cannot add a second user message in a row
@ -203,7 +205,7 @@ async def test_message_filtering(
native="assistant-reply-native",
)
)
# Different agent, will be filtered out.
# Different agent, native messages will be filtered out.
chat_session.async_add_message(
session.ChatMessage(
role="native", agent_id="another-mock-agent-id", content="", native=1
@ -214,11 +216,20 @@ async def test_message_filtering(
role="native", agent_id="mock-agent-id", content="", native=1
)
)
# A non-native message from another agent is not filtered out.
chat_session.async_add_message(
session.ChatMessage(
role="assistant",
agent_id="another-mock-agent-id",
content="Hi!",
native=1,
)
)
assert len(chat_session.messages) == 5
assert len(chat_session.messages) == 6
messages = chat_session.async_get_messages(agent_id="mock-agent-id")
assert len(messages) == 4
assert len(messages) == 5
assert messages[2] == session.ChatMessage(
role="assistant",
@ -229,6 +240,9 @@ async def test_message_filtering(
assert messages[3] == session.ChatMessage(
role="native", agent_id="mock-agent-id", content="", native=1
)
assert messages[4] == session.ChatMessage(
role="assistant", agent_id="another-mock-agent-id", content="Hi!", native=1
)
async def test_llm_api(
@ -413,3 +427,81 @@ async def test_extra_systen_prompt(
assert chat_session.extra_system_prompt == extra_system_prompt2
assert chat_session.messages[0].content.endswith(extra_system_prompt2)
async def test_tool_call(
    hass: HomeAssistant,
    mock_conversation_input: ConversationInput,
) -> None:
    """Test using the session tool calling API."""
    # Stand-in for an Assist LLM tool whose invocation succeeds.
    tool = AsyncMock()
    tool.name = "test_tool"
    tool.description = "Test function"
    tool.parameters = vol.Schema(
        {vol.Optional("param1", description="Test parameters"): str}
    )
    tool.async_call.return_value = "Test response"

    with patch(
        "homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
        return_value=[],
    ) as mock_get_tools:
        mock_get_tools.return_value = [tool]
        async with session.async_get_chat_session(
            hass, mock_conversation_input
        ) as chat_session:
            await chat_session.async_update_llm_data(
                conversing_domain="test",
                user_input=mock_conversation_input,
                user_llm_hass_api="assist",
                user_llm_prompt=None,
            )
            tool_input = llm.ToolInput(
                tool_name="test_tool",
                tool_args={"param1": "Test Param"},
            )
            result = await chat_session.async_call_tool(tool_input)

    # The raw tool response is passed straight through.
    assert result == "Test response"
async def test_tool_call_exception(
    hass: HomeAssistant,
    mock_conversation_input: ConversationInput,
) -> None:
    """Test that a failing tool call is reported as an error payload."""
    # Stand-in for an Assist LLM tool whose invocation always fails.
    failing_tool = AsyncMock()
    failing_tool.name = "test_tool"
    failing_tool.description = "Test function"
    failing_tool.parameters = vol.Schema(
        {vol.Optional("param1", description="Test parameters"): str}
    )
    failing_tool.async_call.side_effect = HomeAssistantError("Test error")

    with patch(
        "homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
        return_value=[],
    ) as mock_get_tools:
        mock_get_tools.return_value = [failing_tool]
        async with session.async_get_chat_session(
            hass, mock_conversation_input
        ) as chat_session:
            await chat_session.async_update_llm_data(
                conversing_domain="test",
                user_input=mock_conversation_input,
                user_llm_hass_api="assist",
                user_llm_prompt=None,
            )
            result = await chat_session.async_call_tool(
                llm.ToolInput(
                    tool_name="test_tool",
                    tool_args={"param1": "Test Param"},
                )
            )

    # The exception is converted to an error dict rather than propagated.
    assert result == {"error": "HomeAssistantError", "error_text": "Test error"}

View File

@ -61,18 +61,18 @@ async def test_converation_trace(
}
async def test_converation_trace_error(
async def test_converation_trace_uncaught_error(
hass: HomeAssistant,
init_components: None,
sl_setup: None,
) -> None:
"""Test tracing a conversation."""
"""Test tracing a conversation that raises an uncaught error."""
with (
patch(
"homeassistant.components.conversation.default_agent.DefaultAgent.async_process",
side_effect=HomeAssistantError("Failed to talk to agent"),
side_effect=ValueError("Unexpected error"),
),
pytest.raises(HomeAssistantError),
pytest.raises(ValueError),
):
await conversation.async_converse(
hass, "add apples to my shopping list", None, Context()
@ -87,4 +87,35 @@ async def test_converation_trace_error(
assert (
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
)
assert last_trace.get("error") == "Failed to talk to agent"
assert last_trace.get("error") == "Unexpected error"
assert not last_trace.get("result")
async def test_converation_trace_homeassistant_error(
    hass: HomeAssistant,
    init_components: None,
    sl_setup: None,
) -> None:
    """Test tracing a conversation with a HomeAssistant error."""
    with patch(
        "homeassistant.components.conversation.default_agent.DefaultAgent.async_process",
        side_effect=HomeAssistantError("Failed to talk to agent"),
    ):
        await conversation.async_converse(
            hass, "add apples to my shopping list", None, Context()
        )

    # The error must not propagate; it is recorded on the trace result instead.
    traces = trace.async_get_traces()
    assert traces
    trace_data = traces[-1].as_dict()
    events = trace_data.get("events")
    assert events
    assert len(events) == 1
    assert (
        events[0].get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
    )
    result = trace_data.get("result")
    assert result
    assert result["response"]["speech"]["plain"]["speech"] == "Failed to talk to agent"