Conversation chat log cleanup and optimization (#137784)

This commit is contained in:
Paulus Schoutsen 2025-02-08 01:06:16 -05:00 committed by GitHub
parent aa19207ea4
commit f64b494282
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 130 additions and 63 deletions

View File

@ -1093,16 +1093,18 @@ class PipelineRun:
agent_id = conversation.HOME_ASSISTANT_AGENT
processed_locally = True
# It was already handled, create response and add to chat history
if intent_response is not None:
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,
conversation.async_get_chat_log(
self.hass, session, user_input
) as chat_log,
):
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,
conversation.async_get_chat_log(
self.hass,
session,
user_input,
) as chat_log,
):
# It was already handled, create response and add to chat history
if intent_response is not None:
speech: str = intent_response.speech.get("plain", {}).get(
"speech", ""
)
@ -1117,21 +1119,21 @@ class PipelineRun:
conversation_id=session.conversation_id,
)
else:
# Fall back to pipeline conversation agent
conversation_result = await conversation.async_converse(
hass=self.hass,
text=user_input.text,
conversation_id=user_input.conversation_id,
device_id=user_input.device_id,
context=user_input.context,
language=user_input.language,
agent_id=user_input.agent_id,
extra_system_prompt=user_input.extra_system_prompt,
)
speech = conversation_result.response.speech.get("plain", {}).get(
"speech", ""
)
else:
# Fall back to pipeline conversation agent
conversation_result = await conversation.async_converse(
hass=self.hass,
text=user_input.text,
conversation_id=user_input.conversation_id,
device_id=user_input.device_id,
context=user_input.context,
language=user_input.language,
agent_id=user_input.agent_id,
extra_system_prompt=user_input.extra_system_prompt,
)
speech = conversation_result.response.speech.get("plain", {}).get(
"speech", ""
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during intent recognition")

View File

@ -1,9 +1,11 @@
"""Conversation history."""
"""Conversation chat log."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator, Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field, replace
import logging
@ -19,10 +21,14 @@ from . import trace
from .const import DOMAIN
from .models import ConversationInput, ConversationResult
DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log")
DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
LOGGER = logging.getLogger(__name__)
current_chat_log: ContextVar[ChatLog | None] = ContextVar(
"current_chat_log", default=None
)
@contextmanager
def async_get_chat_log(
@ -31,41 +37,50 @@ def async_get_chat_log(
user_input: ConversationInput | None = None,
) -> Generator[ChatLog]:
"""Return chat log for a specific chat session."""
all_history = hass.data.get(DATA_CHAT_HISTORY)
if all_history is None:
all_history = {}
hass.data[DATA_CHAT_HISTORY] = all_history
if chat_log := current_chat_log.get():
# If a chat log is already active and it's the requested conversation ID,
# return that. We won't update the last updated time in this case.
if chat_log.conversation_id == session.conversation_id:
yield chat_log
return
history = all_history.get(session.conversation_id)
all_chat_logs = hass.data.get(DATA_CHAT_LOGS)
if all_chat_logs is None:
all_chat_logs = {}
hass.data[DATA_CHAT_LOGS] = all_chat_logs
if history:
history = replace(history, content=history.content.copy())
chat_log = all_chat_logs.get(session.conversation_id)
if chat_log:
chat_log = replace(chat_log, content=chat_log.content.copy())
else:
history = ChatLog(hass, session.conversation_id)
chat_log = ChatLog(hass, session.conversation_id)
if user_input is not None:
history.async_add_user_content(UserContent(content=user_input.text))
chat_log.async_add_user_content(UserContent(content=user_input.text))
last_message = history.content[-1]
last_message = chat_log.content[-1]
yield history
token = current_chat_log.set(chat_log)
yield chat_log
current_chat_log.reset(token)
if history.content[-1] is last_message:
if chat_log.content[-1] is last_message:
LOGGER.debug(
"History opened but no assistant message was added, ignoring update"
"Chat Log opened but no assistant message was added, ignoring update"
)
return
if session.conversation_id not in all_history:
if session.conversation_id not in all_chat_logs:
@callback
def do_cleanup() -> None:
"""Handle cleanup."""
all_history.pop(session.conversation_id)
all_chat_logs.pop(session.conversation_id)
session.async_on_cleanup(do_cleanup)
all_history[session.conversation_id] = history
all_chat_logs[session.conversation_id] = chat_log
class ConverseError(HomeAssistantError):
@ -112,7 +127,7 @@ class AssistantContent:
role: str = field(init=False, default="assistant")
agent_id: str
content: str
content: str | None = None
tool_calls: list[llm.ToolInput] | None = None
@ -143,6 +158,7 @@ class ChatLog:
@callback
def async_add_user_content(self, content: UserContent) -> None:
"""Add user content to the log."""
LOGGER.debug("Adding user content: %s", content)
self.content.append(content)
@callback
@ -150,14 +166,24 @@ class ChatLog:
self, content: AssistantContent
) -> None:
"""Add assistant content to the log."""
LOGGER.debug("Adding assistant content: %s", content)
if content.tool_calls is not None:
raise ValueError("Tool calls not allowed")
self.content.append(content)
async def async_add_assistant_content(
self, content: AssistantContent
self,
content: AssistantContent,
/,
tool_call_tasks: dict[str, asyncio.Task] | None = None,
) -> AsyncGenerator[ToolResultContent]:
"""Add assistant content."""
"""Add assistant content and execute tool calls.
tool_call_tasks can contains tasks for tool calls that are already in progress.
This method is an async generator and will yield the tool results as they come in.
"""
LOGGER.debug("Adding assistant content: %s", content)
self.content.append(content)
if content.tool_calls is None:
@ -166,13 +192,22 @@ class ChatLog:
if self.llm_api is None:
raise ValueError("No LLM API configured")
if tool_call_tasks is None:
tool_call_tasks = {}
for tool_input in content.tool_calls:
if tool_input.id not in tool_call_tasks:
tool_call_tasks[tool_input.id] = self.hass.async_create_task(
self.llm_api.async_call_tool(tool_input),
name=f"llm_tool_{tool_input.id}",
)
for tool_input in content.tool_calls:
LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)
try:
tool_result = await self.llm_api.async_call_tool(tool_input)
tool_result = await tool_call_tasks[tool_input.id]
except (HomeAssistantError, vol.Invalid) as e:
tool_result = {"error": type(e).__name__}
if str(e):

View File

@ -15,7 +15,7 @@ from homeassistant.components.conversation import (
ToolResultContent,
async_get_chat_log,
)
from homeassistant.components.conversation.chat_log import DATA_CHAT_HISTORY
from homeassistant.components.conversation.chat_log import DATA_CHAT_LOGS
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, llm
@ -63,7 +63,7 @@ async def test_cleanup(
)
)
assert conversation_id in hass.data[DATA_CHAT_HISTORY]
assert conversation_id in hass.data[DATA_CHAT_LOGS]
# Set the last updated to be older than the timeout
hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
@ -75,7 +75,7 @@ async def test_cleanup(
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
)
assert conversation_id not in hass.data[DATA_CHAT_HISTORY]
assert conversation_id not in hass.data[DATA_CHAT_LOGS]
async def test_default_content(
@ -279,9 +279,18 @@ async def test_extra_systen_prompt(
assert chat_log.content[0].content.endswith(extra_system_prompt2)
@pytest.mark.parametrize(
"prerun_tool_tasks",
[
None,
("mock-tool-call-id",),
("mock-tool-call-id", "mock-tool-call-id-2"),
],
)
async def test_tool_call(
hass: HomeAssistant,
mock_conversation_input: ConversationInput,
prerun_tool_tasks: tuple[str] | None,
) -> None:
"""Test using the session tool calling API."""
@ -316,26 +325,47 @@ async def test_tool_call(
id="mock-tool-call-id",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
)
),
llm.ToolInput(
id="mock-tool-call-id-2",
tool_name="test_tool",
tool_args={"param1": "Test Param"},
),
],
)
tool_call_tasks = None
if prerun_tool_tasks:
tool_call_tasks = {
tool_call_id: hass.async_create_task(
chat_log.llm_api.async_call_tool(content.tool_calls[0]),
tool_call_id,
)
for tool_call_id in prerun_tool_tasks
}
with pytest.raises(ValueError):
chat_log.async_add_assistant_content_without_tools(content)
result = None
async for tool_result_content in chat_log.async_add_assistant_content(
content
):
assert result is None
result = tool_result_content
results = [
tool_result_content
async for tool_result_content in chat_log.async_add_assistant_content(
content, tool_call_tasks=tool_call_tasks
)
]
assert result == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
tool_call_id="mock-tool-call-id",
tool_result="Test response",
tool_name="test_tool",
)
assert results[0] == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
tool_call_id="mock-tool-call-id",
tool_result="Test response",
tool_name="test_tool",
)
assert results[1] == ToolResultContent(
agent_id=mock_conversation_input.agent_id,
tool_call_id="mock-tool-call-id-2",
tool_result="Test response",
tool_name="test_tool",
)
async def test_tool_call_exception(