diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 94e2b04d7ae..ef26e1a5a6d 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -1093,16 +1093,18 @@ class PipelineRun: agent_id = conversation.HOME_ASSISTANT_AGENT processed_locally = True - # It was already handled, create response and add to chat history - if intent_response is not None: - with ( - chat_session.async_get_chat_session( - self.hass, user_input.conversation_id - ) as session, - conversation.async_get_chat_log( - self.hass, session, user_input - ) as chat_log, - ): + with ( + chat_session.async_get_chat_session( + self.hass, user_input.conversation_id + ) as session, + conversation.async_get_chat_log( + self.hass, + session, + user_input, + ) as chat_log, + ): + # It was already handled, create response and add to chat history + if intent_response is not None: speech: str = intent_response.speech.get("plain", {}).get( "speech", "" ) @@ -1117,21 +1119,21 @@ class PipelineRun: conversation_id=session.conversation_id, ) - else: - # Fall back to pipeline conversation agent - conversation_result = await conversation.async_converse( - hass=self.hass, - text=user_input.text, - conversation_id=user_input.conversation_id, - device_id=user_input.device_id, - context=user_input.context, - language=user_input.language, - agent_id=user_input.agent_id, - extra_system_prompt=user_input.extra_system_prompt, - ) - speech = conversation_result.response.speech.get("plain", {}).get( - "speech", "" - ) + else: + # Fall back to pipeline conversation agent + conversation_result = await conversation.async_converse( + hass=self.hass, + text=user_input.text, + conversation_id=user_input.conversation_id, + device_id=user_input.device_id, + context=user_input.context, + language=user_input.language, + agent_id=user_input.agent_id, + extra_system_prompt=user_input.extra_system_prompt, + ) + speech = conversation_result.response.speech.get("plain", {}).get( + "speech", "" + ) except Exception as src_error: _LOGGER.exception("Unexpected error during intent recognition") diff --git a/homeassistant/components/conversation/chat_log.py b/homeassistant/components/conversation/chat_log.py index ad7a9d0ce9e..e4ff1904e7c 100644 --- a/homeassistant/components/conversation/chat_log.py +++ b/homeassistant/components/conversation/chat_log.py @@ -1,9 +1,11 @@ -"""Conversation history.""" +"""Conversation chat log.""" from __future__ import annotations +import asyncio from collections.abc import AsyncGenerator, Generator from contextlib import contextmanager +from contextvars import ContextVar from dataclasses import dataclass, field, replace import logging @@ -19,10 +21,14 @@ from . import trace from .const import DOMAIN from .models import ConversationInput, ConversationResult -DATA_CHAT_HISTORY: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_log") +DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs") LOGGER = logging.getLogger(__name__) +current_chat_log: ContextVar[ChatLog | None] = ContextVar( + "current_chat_log", default=None +) + @contextmanager def async_get_chat_log( @@ -31,41 +37,50 @@ def async_get_chat_log( user_input: ConversationInput | None = None, ) -> Generator[ChatLog]: """Return chat log for a specific chat session.""" - all_history = hass.data.get(DATA_CHAT_HISTORY) - if all_history is None: - all_history = {} - hass.data[DATA_CHAT_HISTORY] = all_history + if chat_log := current_chat_log.get(): + # If a chat log is already active and it's the requested conversation ID, + # return that. We won't update the last updated time in this case. + if chat_log.conversation_id == session.conversation_id: + yield chat_log + return - history = all_history.get(session.conversation_id) + all_chat_logs = hass.data.get(DATA_CHAT_LOGS) + if all_chat_logs is None: + all_chat_logs = {} + hass.data[DATA_CHAT_LOGS] = all_chat_logs - if history: - history = replace(history, content=history.content.copy()) + chat_log = all_chat_logs.get(session.conversation_id) + + if chat_log: + chat_log = replace(chat_log, content=chat_log.content.copy()) else: - history = ChatLog(hass, session.conversation_id) + chat_log = ChatLog(hass, session.conversation_id) if user_input is not None: - history.async_add_user_content(UserContent(content=user_input.text)) + chat_log.async_add_user_content(UserContent(content=user_input.text)) - last_message = history.content[-1] + last_message = chat_log.content[-1] - yield history + token = current_chat_log.set(chat_log) + yield chat_log + current_chat_log.reset(token) - if history.content[-1] is last_message: + if chat_log.content[-1] is last_message: LOGGER.debug( - "History opened but no assistant message was added, ignoring update" + "Chat Log opened but no assistant message was added, ignoring update" ) return - if session.conversation_id not in all_history: + if session.conversation_id not in all_chat_logs: @callback def do_cleanup() -> None: """Handle cleanup.""" - all_history.pop(session.conversation_id) + all_chat_logs.pop(session.conversation_id) session.async_on_cleanup(do_cleanup) - all_history[session.conversation_id] = history + all_chat_logs[session.conversation_id] = chat_log class ConverseError(HomeAssistantError): @@ -112,7 +127,7 @@ class AssistantContent: role: str = field(init=False, default="assistant") agent_id: str - content: str + content: str | None = None tool_calls: list[llm.ToolInput] | None = None @@ -143,6 +158,7 @@ class ChatLog: @callback def async_add_user_content(self, content: UserContent) -> None: """Add user content to the log.""" + LOGGER.debug("Adding user content: %s", content) self.content.append(content) @callback @@ -150,14 +166,24 @@ class ChatLog: self, content: AssistantContent ) -> None: """Add assistant content to the log.""" + LOGGER.debug("Adding assistant content: %s", content) if content.tool_calls is not None: raise ValueError("Tool calls not allowed") self.content.append(content) async def async_add_assistant_content( - self, content: AssistantContent + self, + content: AssistantContent, + /, + tool_call_tasks: dict[str, asyncio.Task] | None = None, ) -> AsyncGenerator[ToolResultContent]: - """Add assistant content.""" + """Add assistant content and execute tool calls. + + tool_call_tasks can contains tasks for tool calls that are already in progress. + + This method is an async generator and will yield the tool results as they come in. + """ + LOGGER.debug("Adding assistant content: %s", content) self.content.append(content) if content.tool_calls is None: @@ -166,13 +192,22 @@ class ChatLog: if self.llm_api is None: raise ValueError("No LLM API configured") + if tool_call_tasks is None: + tool_call_tasks = {} + for tool_input in content.tool_calls: + if tool_input.id not in tool_call_tasks: + tool_call_tasks[tool_input.id] = self.hass.async_create_task( + self.llm_api.async_call_tool(tool_input), + name=f"llm_tool_{tool_input.id}", + ) + for tool_input in content.tool_calls: LOGGER.debug( "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args ) try: - tool_result = await self.llm_api.async_call_tool(tool_input) + tool_result = await tool_call_tasks[tool_input.id] except (HomeAssistantError, vol.Invalid) as e: tool_result = {"error": type(e).__name__} if str(e): diff --git a/tests/components/conversation/test_chat_log.py b/tests/components/conversation/test_chat_log.py index c22a90e6928..1f659b8005e 100644 --- a/tests/components/conversation/test_chat_log.py +++ b/tests/components/conversation/test_chat_log.py @@ -15,7 +15,7 @@ from homeassistant.components.conversation import ( ToolResultContent, async_get_chat_log, ) -from homeassistant.components.conversation.chat_log import DATA_CHAT_HISTORY +from homeassistant.components.conversation.chat_log import DATA_CHAT_LOGS from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import chat_session, llm @@ -63,7 +63,7 @@ async def test_cleanup( ) ) - assert conversation_id in hass.data[DATA_CHAT_HISTORY] + assert conversation_id in hass.data[DATA_CHAT_LOGS] # Set the last updated to be older than the timeout hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = ( @@ -75,7 +75,7 @@ async def test_cleanup( dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1), ) - assert conversation_id not in hass.data[DATA_CHAT_HISTORY] + assert conversation_id not in hass.data[DATA_CHAT_LOGS] async def test_default_content( @@ -279,9 +279,18 @@ async def test_extra_systen_prompt( assert chat_log.content[0].content.endswith(extra_system_prompt2) +@pytest.mark.parametrize( + "prerun_tool_tasks", + [ + None, + ("mock-tool-call-id",), + ("mock-tool-call-id", "mock-tool-call-id-2"), + ], +) async def test_tool_call( hass: HomeAssistant, mock_conversation_input: ConversationInput, + prerun_tool_tasks: tuple[str] | None, ) -> None: """Test using the session tool calling API.""" @@ -316,26 +325,47 @@ async def test_tool_call( id="mock-tool-call-id", tool_name="test_tool", tool_args={"param1": "Test Param"}, - ) + ), + llm.ToolInput( + id="mock-tool-call-id-2", + tool_name="test_tool", + tool_args={"param1": "Test Param"}, + ), ], ) + tool_call_tasks = None + if prerun_tool_tasks: + tool_call_tasks = { + tool_call_id: hass.async_create_task( + chat_log.llm_api.async_call_tool(content.tool_calls[0]), + tool_call_id, + ) + for tool_call_id in prerun_tool_tasks + } + with pytest.raises(ValueError): chat_log.async_add_assistant_content_without_tools(content) - result = None - async for tool_result_content in chat_log.async_add_assistant_content( - content - ): - assert result is None - result = tool_result_content + results = [ + tool_result_content + async for tool_result_content in chat_log.async_add_assistant_content( + content, tool_call_tasks=tool_call_tasks + ) + ] - assert result == ToolResultContent( - agent_id=mock_conversation_input.agent_id, - tool_call_id="mock-tool-call-id", - tool_result="Test response", - tool_name="test_tool", - ) + assert results[0] == ToolResultContent( + agent_id=mock_conversation_input.agent_id, + tool_call_id="mock-tool-call-id", + tool_result="Test response", + tool_name="test_tool", + ) + assert results[1] == ToolResultContent( + agent_id=mock_conversation_input.agent_id, + tool_call_id="mock-tool-call-id-2", + tool_result="Test response", + tool_name="test_tool", + ) async def test_tool_call_exception(