Conversation chat log cleanup and optimization (#137784)

Paulus Schoutsen 2025-02-08 01:06:16 -05:00 committed by GitHub
parent aa19207ea4
commit f64b494282
GPG Key ID: B5690EEEBB952194
3 changed files with 130 additions and 63 deletions

View File

@@ -1093,16 +1093,18 @@ class PipelineRun:
                 agent_id = conversation.HOME_ASSISTANT_AGENT
                 processed_locally = True
 
-            # It was already handled, create response and add to chat history
-            if intent_response is not None:
-                with (
-                    chat_session.async_get_chat_session(
-                        self.hass, user_input.conversation_id
-                    ) as session,
-                    conversation.async_get_chat_log(
-                        self.hass, session, user_input
-                    ) as chat_log,
-                ):
+            with (
+                chat_session.async_get_chat_session(
+                    self.hass, user_input.conversation_id
+                ) as session,
+                conversation.async_get_chat_log(
+                    self.hass,
+                    session,
+                    user_input,
+                ) as chat_log,
+            ):
+                # It was already handled, create response and add to chat history
+                if intent_response is not None:
                     speech: str = intent_response.speech.get("plain", {}).get(
                         "speech", ""
                     )
@@ -1117,21 +1119,21 @@ class PipelineRun:
                         conversation_id=session.conversation_id,
                     )
-            else:
-                # Fall back to pipeline conversation agent
-                conversation_result = await conversation.async_converse(
-                    hass=self.hass,
-                    text=user_input.text,
-                    conversation_id=user_input.conversation_id,
-                    device_id=user_input.device_id,
-                    context=user_input.context,
-                    language=user_input.language,
-                    agent_id=user_input.agent_id,
-                    extra_system_prompt=user_input.extra_system_prompt,
-                )
-                speech = conversation_result.response.speech.get("plain", {}).get(
-                    "speech", ""
-                )
+                else:
+                    # Fall back to pipeline conversation agent
+                    conversation_result = await conversation.async_converse(
+                        hass=self.hass,
+                        text=user_input.text,
+                        conversation_id=user_input.conversation_id,
+                        device_id=user_input.device_id,
+                        context=user_input.context,
+                        language=user_input.language,
+                        agent_id=user_input.agent_id,
+                        extra_system_prompt=user_input.extra_system_prompt,
+                    )
+                    speech = conversation_result.response.speech.get("plain", {}).get(
+                        "speech", ""
+                    )
 
         except Exception as src_error:
             _LOGGER.exception("Unexpected error during intent recognition")
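
For readers skimming the diff: the net effect is that the chat session and chat log are now opened once, before deciding between the locally handled intent and the fallback conversation agent. A minimal sketch of that acquisition pattern outside the pipeline, assuming an event-loop context and an illustrative agent id (not code from this commit):

from homeassistant.components import conversation
from homeassistant.components.conversation.chat_log import AssistantContent
from homeassistant.core import HomeAssistant
from homeassistant.helpers import chat_session


async def handle_text(
    hass: HomeAssistant, user_input: conversation.ConversationInput
) -> None:
    """Sketch: open the session and chat log together, then branch inside."""
    with (
        # Reuse or create the chat session for this conversation id.
        chat_session.async_get_chat_session(
            hass, user_input.conversation_id
        ) as session,
        # Attach a ChatLog to that session; passing user_input records the
        # user's text as a UserContent entry.
        conversation.async_get_chat_log(hass, session, user_input) as chat_log,
    ):
        # Both branches (local intent match or fallback agent) now run inside
        # the same context, so their output lands in a single chat log.
        chat_log.async_add_assistant_content_without_tools(
            AssistantContent(
                agent_id="assistant.sketch",  # illustrative, not a real agent id
                content="Done!",
            )
        )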

View File

@@ -1,9 +1,11 @@
-"""Conversation history."""
+"""Conversation chat log."""
 
 from __future__ import annotations
 
+import asyncio
 from collections.abc import AsyncGenerator, Generator
 from contextlib import contextmanager
+from contextvars import ContextVar
 from dataclasses import dataclass, field, replace
 import logging
@@ -19,10 +21,14 @@ from . import trace
 from .const import DOMAIN
 from .models import ConversationInput, ConversationResult
 
-DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log")
+DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
 
 LOGGER = logging.getLogger(__name__)
 
+current_chat_log: ContextVar[ChatLog | None] = ContextVar(
+    "current_chat_log", default=None
+)
+
 
 @contextmanager
 def async_get_chat_log(
@@ -31,41 +37,50 @@ def async_get_chat_log(
     user_input: ConversationInput | None = None,
 ) -> Generator[ChatLog]:
     """Return chat log for a specific chat session."""
-    all_history = hass.data.get(DATA_CHAT_HISTORY)
-    if all_history is None:
-        all_history = {}
-        hass.data[DATA_CHAT_HISTORY] = all_history
+    if chat_log := current_chat_log.get():
+        # If a chat log is already active and it's the requested conversation ID,
+        # return that. We won't update the last updated time in this case.
+        if chat_log.conversation_id == session.conversation_id:
+            yield chat_log
+            return
 
-    history = all_history.get(session.conversation_id)
+    all_chat_logs = hass.data.get(DATA_CHAT_LOGS)
+    if all_chat_logs is None:
+        all_chat_logs = {}
+        hass.data[DATA_CHAT_LOGS] = all_chat_logs
 
-    if history:
-        history = replace(history, content=history.content.copy())
+    chat_log = all_chat_logs.get(session.conversation_id)
+
+    if chat_log:
+        chat_log = replace(chat_log, content=chat_log.content.copy())
     else:
-        history = ChatLog(hass, session.conversation_id)
+        chat_log = ChatLog(hass, session.conversation_id)
 
     if user_input is not None:
-        history.async_add_user_content(UserContent(content=user_input.text))
+        chat_log.async_add_user_content(UserContent(content=user_input.text))
 
-    last_message = history.content[-1]
+    last_message = chat_log.content[-1]
 
-    yield history
+    token = current_chat_log.set(chat_log)
+    yield chat_log
+    current_chat_log.reset(token)
 
-    if history.content[-1] is last_message:
+    if chat_log.content[-1] is last_message:
         LOGGER.debug(
-            "History opened but no assistant message was added, ignoring update"
+            "Chat Log opened but no assistant message was added, ignoring update"
        )
         return
 
-    if session.conversation_id not in all_history:
+    if session.conversation_id not in all_chat_logs:
 
         @callback
         def do_cleanup() -> None:
             """Handle cleanup."""
-            all_history.pop(session.conversation_id)
+            all_chat_logs.pop(session.conversation_id)
 
         session.async_on_cleanup(do_cleanup)
 
-    all_history[session.conversation_id] = history
+    all_chat_logs[session.conversation_id] = chat_log
 
 
 class ConverseError(HomeAssistantError):
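
Because async_get_chat_log now records the active log in the current_chat_log ContextVar, a nested call for the same conversation yields the very same ChatLog object instead of taking another copy, and it skips the last-updated bookkeeping. A hedged illustration of that re-entrant behaviour, assuming it runs inside the event loop (not code from this commit):

from homeassistant.components import conversation
from homeassistant.core import HomeAssistant
from homeassistant.helpers import chat_session


def nested_chat_log_access(
    hass: HomeAssistant, user_input: conversation.ConversationInput
) -> None:
    """Sketch: a nested async_get_chat_log call for the same conversation."""
    with (
        chat_session.async_get_chat_session(
            hass, user_input.conversation_id
        ) as session,
        conversation.async_get_chat_log(hass, session, user_input) as outer_log,
    ):
        # Deeper in the same call stack (for example inside an agent), asking
        # for the chat log of the same conversation returns the active object.
        with conversation.async_get_chat_log(hass, session) as inner_log:
            assert inner_log is outer_log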
@@ -112,7 +127,7 @@ class AssistantContent:
     role: str = field(init=False, default="assistant")
     agent_id: str
-    content: str
+    content: str | None = None
     tool_calls: list[llm.ToolInput] | None = None
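
Making content optional means an assistant turn can consist solely of tool calls, with any spoken text arriving in a later message. A tiny hypothetical construction (identifiers are illustrative):

from homeassistant.components.conversation.chat_log import AssistantContent
from homeassistant.helpers import llm

# An assistant message that only requests a tool call and has no text yet.
tool_only_turn = AssistantContent(
    agent_id="assistant.sketch",  # illustrative agent id
    tool_calls=[
        llm.ToolInput(
            id="call-1",
            tool_name="test_tool",
            tool_args={"param1": "value"},
        )
    ],
)
assert tool_only_turn.content is None
assert tool_only_turn.role == "assistant"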
@@ -143,6 +158,7 @@ class ChatLog:
     @callback
     def async_add_user_content(self, content: UserContent) -> None:
         """Add user content to the log."""
+        LOGGER.debug("Adding user content: %s", content)
         self.content.append(content)
 
     @callback
@@ -150,14 +166,24 @@ class ChatLog:
         self, content: AssistantContent
     ) -> None:
         """Add assistant content to the log."""
+        LOGGER.debug("Adding assistant content: %s", content)
         if content.tool_calls is not None:
             raise ValueError("Tool calls not allowed")
         self.content.append(content)
 
     async def async_add_assistant_content(
-        self, content: AssistantContent
+        self,
+        content: AssistantContent,
+        /,
+        tool_call_tasks: dict[str, asyncio.Task] | None = None,
     ) -> AsyncGenerator[ToolResultContent]:
-        """Add assistant content."""
+        """Add assistant content and execute tool calls.
+
+        tool_call_tasks can contain tasks for tool calls that are already in progress.
+
+        This method is an async generator and will yield the tool results as they come in.
+        """
+        LOGGER.debug("Adding assistant content: %s", content)
         self.content.append(content)
 
         if content.tool_calls is None:
@@ -166,13 +192,22 @@ class ChatLog:
 
         if self.llm_api is None:
             raise ValueError("No LLM API configured")
 
+        if tool_call_tasks is None:
+            tool_call_tasks = {}
+        for tool_input in content.tool_calls:
+            if tool_input.id not in tool_call_tasks:
+                tool_call_tasks[tool_input.id] = self.hass.async_create_task(
+                    self.llm_api.async_call_tool(tool_input),
+                    name=f"llm_tool_{tool_input.id}",
+                )
+
         for tool_input in content.tool_calls:
             LOGGER.debug(
                 "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
             )
             try:
-                tool_result = await self.llm_api.async_call_tool(tool_input)
+                tool_result = await tool_call_tasks[tool_input.id]
             except (HomeAssistantError, vol.Invalid) as e:
                 tool_result = {"error": type(e).__name__}
                 if str(e):
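
The new tool_call_tasks parameter lets a caller start tool execution before the assistant message is appended, for example while an LLM response is still streaming; async_add_assistant_content then awaits the pre-created tasks (matched by ToolInput.id) instead of creating duplicates. A hedged sketch of that flow with a hypothetical helper (not code from this commit):

import asyncio

from homeassistant.components.conversation.chat_log import AssistantContent, ChatLog
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm


async def run_tool_calls_early(
    hass: HomeAssistant,
    chat_log: ChatLog,
    agent_id: str,
    tool_inputs: list[llm.ToolInput],
) -> None:
    """Sketch: start tool calls as soon as they are parsed, then log them."""
    if chat_log.llm_api is None:
        raise ValueError("No LLM API configured")

    # Keys must be the ToolInput ids so the chat log can match each running
    # task to its tool call and skip creating a second one.
    tool_call_tasks: dict[str, asyncio.Task] = {
        tool_input.id: hass.async_create_task(
            chat_log.llm_api.async_call_tool(tool_input),
            name=f"early_llm_tool_{tool_input.id}",
        )
        for tool_input in tool_inputs
    }

    content = AssistantContent(agent_id=agent_id, tool_calls=tool_inputs)

    results = [
        tool_result
        async for tool_result in chat_log.async_add_assistant_content(
            content, tool_call_tasks=tool_call_tasks
        )
    ]
    # Each yielded item is the ToolResultContent for the matching tool call.
    assert len(results) == len(tool_inputs)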

View File

@@ -15,7 +15,7 @@ from homeassistant.components.conversation import (
     ToolResultContent,
     async_get_chat_log,
 )
-from homeassistant.components.conversation.chat_log import DATA_CHAT_HISTORY
+from homeassistant.components.conversation.chat_log import DATA_CHAT_LOGS
 from homeassistant.core import Context, HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import chat_session, llm
@@ -63,7 +63,7 @@ async def test_cleanup(
         )
     )
 
-    assert conversation_id in hass.data[DATA_CHAT_HISTORY]
+    assert conversation_id in hass.data[DATA_CHAT_LOGS]
 
     # Set the last updated to be older than the timeout
     hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
@@ -75,7 +75,7 @@ async def test_cleanup(
         dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
     )
 
-    assert conversation_id not in hass.data[DATA_CHAT_HISTORY]
+    assert conversation_id not in hass.data[DATA_CHAT_LOGS]
 
 
 async def test_default_content(
@@ -279,9 +279,18 @@ async def test_extra_systen_prompt(
     assert chat_log.content[0].content.endswith(extra_system_prompt2)
 
 
+@pytest.mark.parametrize(
+    "prerun_tool_tasks",
+    [
+        None,
+        ("mock-tool-call-id",),
+        ("mock-tool-call-id", "mock-tool-call-id-2"),
+    ],
+)
 async def test_tool_call(
     hass: HomeAssistant,
     mock_conversation_input: ConversationInput,
+    prerun_tool_tasks: tuple[str] | None,
 ) -> None:
     """Test using the session tool calling API."""
@@ -316,26 +325,47 @@ async def test_tool_call(
                 id="mock-tool-call-id",
                 tool_name="test_tool",
                 tool_args={"param1": "Test Param"},
-            )
+            ),
+            llm.ToolInput(
+                id="mock-tool-call-id-2",
+                tool_name="test_tool",
+                tool_args={"param1": "Test Param"},
+            ),
         ],
     )
 
+    tool_call_tasks = None
+    if prerun_tool_tasks:
+        tool_call_tasks = {
+            tool_call_id: hass.async_create_task(
+                chat_log.llm_api.async_call_tool(content.tool_calls[0]),
+                tool_call_id,
+            )
+            for tool_call_id in prerun_tool_tasks
+        }
+
     with pytest.raises(ValueError):
         chat_log.async_add_assistant_content_without_tools(content)
 
-    result = None
-    async for tool_result_content in chat_log.async_add_assistant_content(
-        content
-    ):
-        assert result is None
-        result = tool_result_content
+    results = [
+        tool_result_content
+        async for tool_result_content in chat_log.async_add_assistant_content(
+            content, tool_call_tasks=tool_call_tasks
+        )
+    ]
 
-    assert result == ToolResultContent(
+    assert results[0] == ToolResultContent(
         agent_id=mock_conversation_input.agent_id,
         tool_call_id="mock-tool-call-id",
         tool_result="Test response",
         tool_name="test_tool",
     )
+    assert results[1] == ToolResultContent(
+        agent_id=mock_conversation_input.agent_id,
+        tool_call_id="mock-tool-call-id-2",
+        tool_result="Test response",
+        tool_name="test_tool",
+    )
 
 
 async def test_tool_call_exception(