mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 17:57:11 +00:00
Conversation chat log cleanup and optimization (#137784)
This commit is contained in:
parent
aa19207ea4
commit
f64b494282
@ -1093,16 +1093,18 @@ class PipelineRun:
|
|||||||
agent_id = conversation.HOME_ASSISTANT_AGENT
|
agent_id = conversation.HOME_ASSISTANT_AGENT
|
||||||
processed_locally = True
|
processed_locally = True
|
||||||
|
|
||||||
# It was already handled, create response and add to chat history
|
with (
|
||||||
if intent_response is not None:
|
chat_session.async_get_chat_session(
|
||||||
with (
|
self.hass, user_input.conversation_id
|
||||||
chat_session.async_get_chat_session(
|
) as session,
|
||||||
self.hass, user_input.conversation_id
|
conversation.async_get_chat_log(
|
||||||
) as session,
|
self.hass,
|
||||||
conversation.async_get_chat_log(
|
session,
|
||||||
self.hass, session, user_input
|
user_input,
|
||||||
) as chat_log,
|
) as chat_log,
|
||||||
):
|
):
|
||||||
|
# It was already handled, create response and add to chat history
|
||||||
|
if intent_response is not None:
|
||||||
speech: str = intent_response.speech.get("plain", {}).get(
|
speech: str = intent_response.speech.get("plain", {}).get(
|
||||||
"speech", ""
|
"speech", ""
|
||||||
)
|
)
|
||||||
@ -1117,21 +1119,21 @@ class PipelineRun:
|
|||||||
conversation_id=session.conversation_id,
|
conversation_id=session.conversation_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Fall back to pipeline conversation agent
|
# Fall back to pipeline conversation agent
|
||||||
conversation_result = await conversation.async_converse(
|
conversation_result = await conversation.async_converse(
|
||||||
hass=self.hass,
|
hass=self.hass,
|
||||||
text=user_input.text,
|
text=user_input.text,
|
||||||
conversation_id=user_input.conversation_id,
|
conversation_id=user_input.conversation_id,
|
||||||
device_id=user_input.device_id,
|
device_id=user_input.device_id,
|
||||||
context=user_input.context,
|
context=user_input.context,
|
||||||
language=user_input.language,
|
language=user_input.language,
|
||||||
agent_id=user_input.agent_id,
|
agent_id=user_input.agent_id,
|
||||||
extra_system_prompt=user_input.extra_system_prompt,
|
extra_system_prompt=user_input.extra_system_prompt,
|
||||||
)
|
)
|
||||||
speech = conversation_result.response.speech.get("plain", {}).get(
|
speech = conversation_result.response.speech.get("plain", {}).get(
|
||||||
"speech", ""
|
"speech", ""
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as src_error:
|
except Exception as src_error:
|
||||||
_LOGGER.exception("Unexpected error during intent recognition")
|
_LOGGER.exception("Unexpected error during intent recognition")
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
"""Conversation history."""
|
"""Conversation chat log."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator, Generator
|
from collections.abc import AsyncGenerator, Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from contextvars import ContextVar
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -19,10 +21,14 @@ from . import trace
|
|||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .models import ConversationInput, ConversationResult
|
from .models import ConversationInput, ConversationResult
|
||||||
|
|
||||||
DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log")
|
DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
current_chat_log: ContextVar[ChatLog | None] = ContextVar(
|
||||||
|
"current_chat_log", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def async_get_chat_log(
|
def async_get_chat_log(
|
||||||
@ -31,41 +37,50 @@ def async_get_chat_log(
|
|||||||
user_input: ConversationInput | None = None,
|
user_input: ConversationInput | None = None,
|
||||||
) -> Generator[ChatLog]:
|
) -> Generator[ChatLog]:
|
||||||
"""Return chat log for a specific chat session."""
|
"""Return chat log for a specific chat session."""
|
||||||
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
if chat_log := current_chat_log.get():
|
||||||
if all_history is None:
|
# If a chat log is already active and it's the requested conversation ID,
|
||||||
all_history = {}
|
# return that. We won't update the last updated time in this case.
|
||||||
hass.data[DATA_CHAT_HISTORY] = all_history
|
if chat_log.conversation_id == session.conversation_id:
|
||||||
|
yield chat_log
|
||||||
|
return
|
||||||
|
|
||||||
history = all_history.get(session.conversation_id)
|
all_chat_logs = hass.data.get(DATA_CHAT_LOGS)
|
||||||
|
if all_chat_logs is None:
|
||||||
|
all_chat_logs = {}
|
||||||
|
hass.data[DATA_CHAT_LOGS] = all_chat_logs
|
||||||
|
|
||||||
if history:
|
chat_log = all_chat_logs.get(session.conversation_id)
|
||||||
history = replace(history, content=history.content.copy())
|
|
||||||
|
if chat_log:
|
||||||
|
chat_log = replace(chat_log, content=chat_log.content.copy())
|
||||||
else:
|
else:
|
||||||
history = ChatLog(hass, session.conversation_id)
|
chat_log = ChatLog(hass, session.conversation_id)
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
history.async_add_user_content(UserContent(content=user_input.text))
|
chat_log.async_add_user_content(UserContent(content=user_input.text))
|
||||||
|
|
||||||
last_message = history.content[-1]
|
last_message = chat_log.content[-1]
|
||||||
|
|
||||||
yield history
|
token = current_chat_log.set(chat_log)
|
||||||
|
yield chat_log
|
||||||
|
current_chat_log.reset(token)
|
||||||
|
|
||||||
if history.content[-1] is last_message:
|
if chat_log.content[-1] is last_message:
|
||||||
LOGGER.debug(
|
LOGGER.debug(
|
||||||
"History opened but no assistant message was added, ignoring update"
|
"Chat Log opened but no assistant message was added, ignoring update"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if session.conversation_id not in all_history:
|
if session.conversation_id not in all_chat_logs:
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def do_cleanup() -> None:
|
def do_cleanup() -> None:
|
||||||
"""Handle cleanup."""
|
"""Handle cleanup."""
|
||||||
all_history.pop(session.conversation_id)
|
all_chat_logs.pop(session.conversation_id)
|
||||||
|
|
||||||
session.async_on_cleanup(do_cleanup)
|
session.async_on_cleanup(do_cleanup)
|
||||||
|
|
||||||
all_history[session.conversation_id] = history
|
all_chat_logs[session.conversation_id] = chat_log
|
||||||
|
|
||||||
|
|
||||||
class ConverseError(HomeAssistantError):
|
class ConverseError(HomeAssistantError):
|
||||||
@ -112,7 +127,7 @@ class AssistantContent:
|
|||||||
|
|
||||||
role: str = field(init=False, default="assistant")
|
role: str = field(init=False, default="assistant")
|
||||||
agent_id: str
|
agent_id: str
|
||||||
content: str
|
content: str | None = None
|
||||||
tool_calls: list[llm.ToolInput] | None = None
|
tool_calls: list[llm.ToolInput] | None = None
|
||||||
|
|
||||||
|
|
||||||
@ -143,6 +158,7 @@ class ChatLog:
|
|||||||
@callback
|
@callback
|
||||||
def async_add_user_content(self, content: UserContent) -> None:
|
def async_add_user_content(self, content: UserContent) -> None:
|
||||||
"""Add user content to the log."""
|
"""Add user content to the log."""
|
||||||
|
LOGGER.debug("Adding user content: %s", content)
|
||||||
self.content.append(content)
|
self.content.append(content)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -150,14 +166,24 @@ class ChatLog:
|
|||||||
self, content: AssistantContent
|
self, content: AssistantContent
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add assistant content to the log."""
|
"""Add assistant content to the log."""
|
||||||
|
LOGGER.debug("Adding assistant content: %s", content)
|
||||||
if content.tool_calls is not None:
|
if content.tool_calls is not None:
|
||||||
raise ValueError("Tool calls not allowed")
|
raise ValueError("Tool calls not allowed")
|
||||||
self.content.append(content)
|
self.content.append(content)
|
||||||
|
|
||||||
async def async_add_assistant_content(
|
async def async_add_assistant_content(
|
||||||
self, content: AssistantContent
|
self,
|
||||||
|
content: AssistantContent,
|
||||||
|
/,
|
||||||
|
tool_call_tasks: dict[str, asyncio.Task] | None = None,
|
||||||
) -> AsyncGenerator[ToolResultContent]:
|
) -> AsyncGenerator[ToolResultContent]:
|
||||||
"""Add assistant content."""
|
"""Add assistant content and execute tool calls.
|
||||||
|
|
||||||
|
tool_call_tasks can contain tasks for tool calls that are already in progress.
|
||||||
|
|
||||||
|
This method is an async generator and will yield the tool results as they come in.
|
||||||
|
"""
|
||||||
|
LOGGER.debug("Adding assistant content: %s", content)
|
||||||
self.content.append(content)
|
self.content.append(content)
|
||||||
|
|
||||||
if content.tool_calls is None:
|
if content.tool_calls is None:
|
||||||
@ -166,13 +192,22 @@ class ChatLog:
|
|||||||
if self.llm_api is None:
|
if self.llm_api is None:
|
||||||
raise ValueError("No LLM API configured")
|
raise ValueError("No LLM API configured")
|
||||||
|
|
||||||
|
if tool_call_tasks is None:
|
||||||
|
tool_call_tasks = {}
|
||||||
|
for tool_input in content.tool_calls:
|
||||||
|
if tool_input.id not in tool_call_tasks:
|
||||||
|
tool_call_tasks[tool_input.id] = self.hass.async_create_task(
|
||||||
|
self.llm_api.async_call_tool(tool_input),
|
||||||
|
name=f"llm_tool_{tool_input.id}",
|
||||||
|
)
|
||||||
|
|
||||||
for tool_input in content.tool_calls:
|
for tool_input in content.tool_calls:
|
||||||
LOGGER.debug(
|
LOGGER.debug(
|
||||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_result = await self.llm_api.async_call_tool(tool_input)
|
tool_result = await tool_call_tasks[tool_input.id]
|
||||||
except (HomeAssistantError, vol.Invalid) as e:
|
except (HomeAssistantError, vol.Invalid) as e:
|
||||||
tool_result = {"error": type(e).__name__}
|
tool_result = {"error": type(e).__name__}
|
||||||
if str(e):
|
if str(e):
|
||||||
|
@ -15,7 +15,7 @@ from homeassistant.components.conversation import (
|
|||||||
ToolResultContent,
|
ToolResultContent,
|
||||||
async_get_chat_log,
|
async_get_chat_log,
|
||||||
)
|
)
|
||||||
from homeassistant.components.conversation.chat_log import DATA_CHAT_HISTORY
|
from homeassistant.components.conversation.chat_log import DATA_CHAT_LOGS
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import chat_session, llm
|
from homeassistant.helpers import chat_session, llm
|
||||||
@ -63,7 +63,7 @@ async def test_cleanup(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert conversation_id in hass.data[DATA_CHAT_HISTORY]
|
assert conversation_id in hass.data[DATA_CHAT_LOGS]
|
||||||
|
|
||||||
# Set the last updated to be older than the timeout
|
# Set the last updated to be older than the timeout
|
||||||
hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
|
hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
|
||||||
@ -75,7 +75,7 @@ async def test_cleanup(
|
|||||||
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
|
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert conversation_id not in hass.data[DATA_CHAT_HISTORY]
|
assert conversation_id not in hass.data[DATA_CHAT_LOGS]
|
||||||
|
|
||||||
|
|
||||||
async def test_default_content(
|
async def test_default_content(
|
||||||
@ -279,9 +279,18 @@ async def test_extra_systen_prompt(
|
|||||||
assert chat_log.content[0].content.endswith(extra_system_prompt2)
|
assert chat_log.content[0].content.endswith(extra_system_prompt2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"prerun_tool_tasks",
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
("mock-tool-call-id",),
|
||||||
|
("mock-tool-call-id", "mock-tool-call-id-2"),
|
||||||
|
],
|
||||||
|
)
|
||||||
async def test_tool_call(
|
async def test_tool_call(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_conversation_input: ConversationInput,
|
mock_conversation_input: ConversationInput,
|
||||||
|
prerun_tool_tasks: tuple[str] | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test using the session tool calling API."""
|
"""Test using the session tool calling API."""
|
||||||
|
|
||||||
@ -316,26 +325,47 @@ async def test_tool_call(
|
|||||||
id="mock-tool-call-id",
|
id="mock-tool-call-id",
|
||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": "Test Param"},
|
tool_args={"param1": "Test Param"},
|
||||||
)
|
),
|
||||||
|
llm.ToolInput(
|
||||||
|
id="mock-tool-call-id-2",
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={"param1": "Test Param"},
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tool_call_tasks = None
|
||||||
|
if prerun_tool_tasks:
|
||||||
|
tool_call_tasks = {
|
||||||
|
tool_call_id: hass.async_create_task(
|
||||||
|
chat_log.llm_api.async_call_tool(content.tool_calls[0]),
|
||||||
|
tool_call_id,
|
||||||
|
)
|
||||||
|
for tool_call_id in prerun_tool_tasks
|
||||||
|
}
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
chat_log.async_add_assistant_content_without_tools(content)
|
chat_log.async_add_assistant_content_without_tools(content)
|
||||||
|
|
||||||
result = None
|
results = [
|
||||||
async for tool_result_content in chat_log.async_add_assistant_content(
|
tool_result_content
|
||||||
content
|
async for tool_result_content in chat_log.async_add_assistant_content(
|
||||||
):
|
content, tool_call_tasks=tool_call_tasks
|
||||||
assert result is None
|
)
|
||||||
result = tool_result_content
|
]
|
||||||
|
|
||||||
assert result == ToolResultContent(
|
assert results[0] == ToolResultContent(
|
||||||
agent_id=mock_conversation_input.agent_id,
|
agent_id=mock_conversation_input.agent_id,
|
||||||
tool_call_id="mock-tool-call-id",
|
tool_call_id="mock-tool-call-id",
|
||||||
tool_result="Test response",
|
tool_result="Test response",
|
||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
)
|
)
|
||||||
|
assert results[1] == ToolResultContent(
|
||||||
|
agent_id=mock_conversation_input.agent_id,
|
||||||
|
tool_call_id="mock-tool-call-id-2",
|
||||||
|
tool_result="Test response",
|
||||||
|
tool_name="test_tool",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_tool_call_exception(
|
async def test_tool_call_exception(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user