mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Conversation chat log cleanup and optimization (#137784)
This commit is contained in:
parent
aa19207ea4
commit
f64b494282
@ -1093,16 +1093,18 @@ class PipelineRun:
|
||||
agent_id = conversation.HOME_ASSISTANT_AGENT
|
||||
processed_locally = True
|
||||
|
||||
# It was already handled, create response and add to chat history
|
||||
if intent_response is not None:
|
||||
with (
|
||||
chat_session.async_get_chat_session(
|
||||
self.hass, user_input.conversation_id
|
||||
) as session,
|
||||
conversation.async_get_chat_log(
|
||||
self.hass, session, user_input
|
||||
) as chat_log,
|
||||
):
|
||||
with (
|
||||
chat_session.async_get_chat_session(
|
||||
self.hass, user_input.conversation_id
|
||||
) as session,
|
||||
conversation.async_get_chat_log(
|
||||
self.hass,
|
||||
session,
|
||||
user_input,
|
||||
) as chat_log,
|
||||
):
|
||||
# It was already handled, create response and add to chat history
|
||||
if intent_response is not None:
|
||||
speech: str = intent_response.speech.get("plain", {}).get(
|
||||
"speech", ""
|
||||
)
|
||||
@ -1117,21 +1119,21 @@ class PipelineRun:
|
||||
conversation_id=session.conversation_id,
|
||||
)
|
||||
|
||||
else:
|
||||
# Fall back to pipeline conversation agent
|
||||
conversation_result = await conversation.async_converse(
|
||||
hass=self.hass,
|
||||
text=user_input.text,
|
||||
conversation_id=user_input.conversation_id,
|
||||
device_id=user_input.device_id,
|
||||
context=user_input.context,
|
||||
language=user_input.language,
|
||||
agent_id=user_input.agent_id,
|
||||
extra_system_prompt=user_input.extra_system_prompt,
|
||||
)
|
||||
speech = conversation_result.response.speech.get("plain", {}).get(
|
||||
"speech", ""
|
||||
)
|
||||
else:
|
||||
# Fall back to pipeline conversation agent
|
||||
conversation_result = await conversation.async_converse(
|
||||
hass=self.hass,
|
||||
text=user_input.text,
|
||||
conversation_id=user_input.conversation_id,
|
||||
device_id=user_input.device_id,
|
||||
context=user_input.context,
|
||||
language=user_input.language,
|
||||
agent_id=user_input.agent_id,
|
||||
extra_system_prompt=user_input.extra_system_prompt,
|
||||
)
|
||||
speech = conversation_result.response.speech.get("plain", {}).get(
|
||||
"speech", ""
|
||||
)
|
||||
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during intent recognition")
|
||||
|
@ -1,9 +1,11 @@
|
||||
"""Conversation history."""
|
||||
"""Conversation chat log."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field, replace
|
||||
import logging
|
||||
|
||||
@ -19,10 +21,14 @@ from . import trace
|
||||
from .const import DOMAIN
|
||||
from .models import ConversationInput, ConversationResult
|
||||
|
||||
DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log")
|
||||
DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
current_chat_log: ContextVar[ChatLog | None] = ContextVar(
|
||||
"current_chat_log", default=None
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def async_get_chat_log(
|
||||
@ -31,41 +37,50 @@ def async_get_chat_log(
|
||||
user_input: ConversationInput | None = None,
|
||||
) -> Generator[ChatLog]:
|
||||
"""Return chat log for a specific chat session."""
|
||||
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
||||
if all_history is None:
|
||||
all_history = {}
|
||||
hass.data[DATA_CHAT_HISTORY] = all_history
|
||||
if chat_log := current_chat_log.get():
|
||||
# If a chat log is already active and it's the requested conversation ID,
|
||||
# return that. We won't update the last updated time in this case.
|
||||
if chat_log.conversation_id == session.conversation_id:
|
||||
yield chat_log
|
||||
return
|
||||
|
||||
history = all_history.get(session.conversation_id)
|
||||
all_chat_logs = hass.data.get(DATA_CHAT_LOGS)
|
||||
if all_chat_logs is None:
|
||||
all_chat_logs = {}
|
||||
hass.data[DATA_CHAT_LOGS] = all_chat_logs
|
||||
|
||||
if history:
|
||||
history = replace(history, content=history.content.copy())
|
||||
chat_log = all_chat_logs.get(session.conversation_id)
|
||||
|
||||
if chat_log:
|
||||
chat_log = replace(chat_log, content=chat_log.content.copy())
|
||||
else:
|
||||
history = ChatLog(hass, session.conversation_id)
|
||||
chat_log = ChatLog(hass, session.conversation_id)
|
||||
|
||||
if user_input is not None:
|
||||
history.async_add_user_content(UserContent(content=user_input.text))
|
||||
chat_log.async_add_user_content(UserContent(content=user_input.text))
|
||||
|
||||
last_message = history.content[-1]
|
||||
last_message = chat_log.content[-1]
|
||||
|
||||
yield history
|
||||
token = current_chat_log.set(chat_log)
|
||||
yield chat_log
|
||||
current_chat_log.reset(token)
|
||||
|
||||
if history.content[-1] is last_message:
|
||||
if chat_log.content[-1] is last_message:
|
||||
LOGGER.debug(
|
||||
"History opened but no assistant message was added, ignoring update"
|
||||
"Chat Log opened but no assistant message was added, ignoring update"
|
||||
)
|
||||
return
|
||||
|
||||
if session.conversation_id not in all_history:
|
||||
if session.conversation_id not in all_chat_logs:
|
||||
|
||||
@callback
|
||||
def do_cleanup() -> None:
|
||||
"""Handle cleanup."""
|
||||
all_history.pop(session.conversation_id)
|
||||
all_chat_logs.pop(session.conversation_id)
|
||||
|
||||
session.async_on_cleanup(do_cleanup)
|
||||
|
||||
all_history[session.conversation_id] = history
|
||||
all_chat_logs[session.conversation_id] = chat_log
|
||||
|
||||
|
||||
class ConverseError(HomeAssistantError):
|
||||
@ -112,7 +127,7 @@ class AssistantContent:
|
||||
|
||||
role: str = field(init=False, default="assistant")
|
||||
agent_id: str
|
||||
content: str
|
||||
content: str | None = None
|
||||
tool_calls: list[llm.ToolInput] | None = None
|
||||
|
||||
|
||||
@ -143,6 +158,7 @@ class ChatLog:
|
||||
@callback
|
||||
def async_add_user_content(self, content: UserContent) -> None:
|
||||
"""Add user content to the log."""
|
||||
LOGGER.debug("Adding user content: %s", content)
|
||||
self.content.append(content)
|
||||
|
||||
@callback
|
||||
@ -150,14 +166,24 @@ class ChatLog:
|
||||
self, content: AssistantContent
|
||||
) -> None:
|
||||
"""Add assistant content to the log."""
|
||||
LOGGER.debug("Adding assistant content: %s", content)
|
||||
if content.tool_calls is not None:
|
||||
raise ValueError("Tool calls not allowed")
|
||||
self.content.append(content)
|
||||
|
||||
async def async_add_assistant_content(
|
||||
self, content: AssistantContent
|
||||
self,
|
||||
content: AssistantContent,
|
||||
/,
|
||||
tool_call_tasks: dict[str, asyncio.Task] | None = None,
|
||||
) -> AsyncGenerator[ToolResultContent]:
|
||||
"""Add assistant content."""
|
||||
"""Add assistant content and execute tool calls.
|
||||
|
||||
tool_call_tasks can contains tasks for tool calls that are already in progress.
|
||||
|
||||
This method is an async generator and will yield the tool results as they come in.
|
||||
"""
|
||||
LOGGER.debug("Adding assistant content: %s", content)
|
||||
self.content.append(content)
|
||||
|
||||
if content.tool_calls is None:
|
||||
@ -166,13 +192,22 @@ class ChatLog:
|
||||
if self.llm_api is None:
|
||||
raise ValueError("No LLM API configured")
|
||||
|
||||
if tool_call_tasks is None:
|
||||
tool_call_tasks = {}
|
||||
for tool_input in content.tool_calls:
|
||||
if tool_input.id not in tool_call_tasks:
|
||||
tool_call_tasks[tool_input.id] = self.hass.async_create_task(
|
||||
self.llm_api.async_call_tool(tool_input),
|
||||
name=f"llm_tool_{tool_input.id}",
|
||||
)
|
||||
|
||||
for tool_input in content.tool_calls:
|
||||
LOGGER.debug(
|
||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||
)
|
||||
|
||||
try:
|
||||
tool_result = await self.llm_api.async_call_tool(tool_input)
|
||||
tool_result = await tool_call_tasks[tool_input.id]
|
||||
except (HomeAssistantError, vol.Invalid) as e:
|
||||
tool_result = {"error": type(e).__name__}
|
||||
if str(e):
|
||||
|
@ -15,7 +15,7 @@ from homeassistant.components.conversation import (
|
||||
ToolResultContent,
|
||||
async_get_chat_log,
|
||||
)
|
||||
from homeassistant.components.conversation.chat_log import DATA_CHAT_HISTORY
|
||||
from homeassistant.components.conversation.chat_log import DATA_CHAT_LOGS
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import chat_session, llm
|
||||
@ -63,7 +63,7 @@ async def test_cleanup(
|
||||
)
|
||||
)
|
||||
|
||||
assert conversation_id in hass.data[DATA_CHAT_HISTORY]
|
||||
assert conversation_id in hass.data[DATA_CHAT_LOGS]
|
||||
|
||||
# Set the last updated to be older than the timeout
|
||||
hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
|
||||
@ -75,7 +75,7 @@ async def test_cleanup(
|
||||
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
|
||||
)
|
||||
|
||||
assert conversation_id not in hass.data[DATA_CHAT_HISTORY]
|
||||
assert conversation_id not in hass.data[DATA_CHAT_LOGS]
|
||||
|
||||
|
||||
async def test_default_content(
|
||||
@ -279,9 +279,18 @@ async def test_extra_systen_prompt(
|
||||
assert chat_log.content[0].content.endswith(extra_system_prompt2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prerun_tool_tasks",
|
||||
[
|
||||
None,
|
||||
("mock-tool-call-id",),
|
||||
("mock-tool-call-id", "mock-tool-call-id-2"),
|
||||
],
|
||||
)
|
||||
async def test_tool_call(
|
||||
hass: HomeAssistant,
|
||||
mock_conversation_input: ConversationInput,
|
||||
prerun_tool_tasks: tuple[str] | None,
|
||||
) -> None:
|
||||
"""Test using the session tool calling API."""
|
||||
|
||||
@ -316,26 +325,47 @@ async def test_tool_call(
|
||||
id="mock-tool-call-id",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "Test Param"},
|
||||
)
|
||||
),
|
||||
llm.ToolInput(
|
||||
id="mock-tool-call-id-2",
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "Test Param"},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
tool_call_tasks = None
|
||||
if prerun_tool_tasks:
|
||||
tool_call_tasks = {
|
||||
tool_call_id: hass.async_create_task(
|
||||
chat_log.llm_api.async_call_tool(content.tool_calls[0]),
|
||||
tool_call_id,
|
||||
)
|
||||
for tool_call_id in prerun_tool_tasks
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
chat_log.async_add_assistant_content_without_tools(content)
|
||||
|
||||
result = None
|
||||
async for tool_result_content in chat_log.async_add_assistant_content(
|
||||
content
|
||||
):
|
||||
assert result is None
|
||||
result = tool_result_content
|
||||
results = [
|
||||
tool_result_content
|
||||
async for tool_result_content in chat_log.async_add_assistant_content(
|
||||
content, tool_call_tasks=tool_call_tasks
|
||||
)
|
||||
]
|
||||
|
||||
assert result == ToolResultContent(
|
||||
agent_id=mock_conversation_input.agent_id,
|
||||
tool_call_id="mock-tool-call-id",
|
||||
tool_result="Test response",
|
||||
tool_name="test_tool",
|
||||
)
|
||||
assert results[0] == ToolResultContent(
|
||||
agent_id=mock_conversation_input.agent_id,
|
||||
tool_call_id="mock-tool-call-id",
|
||||
tool_result="Test response",
|
||||
tool_name="test_tool",
|
||||
)
|
||||
assert results[1] == ToolResultContent(
|
||||
agent_id=mock_conversation_input.agent_id,
|
||||
tool_call_id="mock-tool-call-id-2",
|
||||
tool_result="Test response",
|
||||
tool_name="test_tool",
|
||||
)
|
||||
|
||||
|
||||
async def test_tool_call_exception(
|
||||
|
Loading…
x
Reference in New Issue
Block a user