diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index ef26e1a5a6d..cf9fb4c7212 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -374,6 +374,7 @@ class PipelineEventType(StrEnum): STT_VAD_END = "stt-vad-end" STT_END = "stt-end" INTENT_START = "intent-start" + INTENT_PROGRESS = "intent-progress" INTENT_END = "intent-end" TTS_START = "tts-start" TTS_END = "tts-end" @@ -1093,6 +1094,20 @@ class PipelineRun: agent_id = conversation.HOME_ASSISTANT_AGENT processed_locally = True + @callback + def chat_log_delta_listener( + chat_log: conversation.ChatLog, delta: dict + ) -> None: + """Handle chat log delta.""" + self.process_event( + PipelineEvent( + PipelineEventType.INTENT_PROGRESS, + { + "chat_log_delta": delta, + }, + ) + ) + with ( chat_session.async_get_chat_session( self.hass, user_input.conversation_id @@ -1101,6 +1116,7 @@ class PipelineRun: self.hass, session, user_input, + chat_log_delta_listener=chat_log_delta_listener, ) as chat_log, ): # It was already handled, create response and add to chat history diff --git a/homeassistant/components/conversation/chat_log.py b/homeassistant/components/conversation/chat_log.py index a060a769907..1ee5e9965ab 100644 --- a/homeassistant/components/conversation/chat_log.py +++ b/homeassistant/components/conversation/chat_log.py @@ -3,10 +3,10 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncGenerator, AsyncIterable, Generator +from collections.abc import AsyncGenerator, AsyncIterable, Callable, Generator from contextlib import contextmanager from contextvars import ContextVar -from dataclasses import dataclass, field, replace +from dataclasses import asdict, dataclass, field, replace import logging from typing import Literal, TypedDict @@ -36,6 +36,8 @@ def async_get_chat_log( hass: HomeAssistant, session: chat_session.ChatSession, user_input: ConversationInput | None = None, + *, + chat_log_delta_listener: Callable[[ChatLog, dict], None] | None = None, ) -> Generator[ChatLog]: """Return chat log for a specific chat session.""" # If a chat log is already active and it's the requested conversation ID, @@ -43,6 +45,10 @@ def async_get_chat_log( if ( chat_log := current_chat_log.get() ) and chat_log.conversation_id == session.conversation_id: + if chat_log_delta_listener is not None: + raise RuntimeError( + "Cannot attach chat log delta listener unless initial caller" + ) if user_input is not None: chat_log.async_add_user_content(UserContent(content=user_input.text)) @@ -59,6 +65,9 @@ def async_get_chat_log( else: chat_log = ChatLog(hass, session.conversation_id) + if chat_log_delta_listener: + chat_log.delta_listener = chat_log_delta_listener + if user_input is not None: chat_log.async_add_user_content(UserContent(content=user_input.text)) @@ -83,6 +92,9 @@ def async_get_chat_log( session.async_on_cleanup(do_cleanup) + if chat_log_delta_listener: + chat_log.delta_listener = None + all_chat_logs[session.conversation_id] = chat_log @@ -165,6 +177,7 @@ class ChatLog: content: list[Content] = field(default_factory=lambda: [SystemContent(content="")]) extra_system_prompt: str | None = None llm_api: llm.APIInstance | None = None + delta_listener: Callable[[ChatLog, dict], None] | None = None @property def unresponded_tool_results(self) -> bool: @@ -275,6 +288,8 @@ class ChatLog: self.llm_api.async_call_tool(tool_call), name=f"llm_tool_{tool_call.id}", ) + if self.delta_listener: + self.delta_listener(self, delta) # type: ignore[arg-type] continue # Starting a new message @@ -294,10 +309,15 @@ class ChatLog: content, tool_call_tasks=tool_call_tasks ): yield tool_result + if self.delta_listener: + self.delta_listener(self, asdict(tool_result)) current_content = delta.get("content") or "" current_tool_calls = delta.get("tool_calls") or [] + if self.delta_listener: + self.delta_listener(self, delta) # type: ignore[arg-type] + if current_content or current_tool_calls: content = AssistantContent( agent_id=agent_id, @@ -309,6 +329,8 @@ class ChatLog: content, tool_call_tasks=tool_call_tasks ): yield tool_result + if self.delta_listener: + self.delta_listener(self, asdict(tool_result)) async def async_update_llm_data( self, diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 2cd56f094dd..f856bbe7f61 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -9,6 +9,7 @@ from unittest.mock import ANY, Mock, patch import pytest from syrupy.assertion import SnapshotAssertion +from homeassistant.components import conversation from homeassistant.components.assist_pipeline.const import ( DOMAIN, SAMPLE_CHANNELS, @@ -22,7 +23,7 @@ from homeassistant.components.assist_pipeline.pipeline import ( ) from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import device_registry as dr +from homeassistant.helpers import chat_session, device_registry as dr from .conftest import ( BYTES_ONE_SECOND, @@ -2727,3 +2728,62 @@ async def test_stt_cooldown_different_ids( # Both should start stt assert {event_type_1, event_type_2} == {"stt-start"} + + +async def test_intent_progress_event( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components, +) -> None: + """Test intent-progress events from a pipeline are forwarded.""" + client = await hass_ws_client(hass) + + orig_converse = conversation.async_converse + expected_delta_events = [ + {"chat_log_delta": {"role": "assistant"}}, + {"chat_log_delta": {"content": "Hello"}}, + ] + + async def mock_delta_stream(): + """Mock delta stream.""" + for d in expected_delta_events: + yield d["chat_log_delta"] + + async def mock_converse(**kwargs): + """Mock converse method.""" + with ( + chat_session.async_get_chat_session( + kwargs["hass"], kwargs["conversation_id"] + ) as session, + conversation.async_get_chat_log(hass, session) as chat_log, + ): + async for _content in chat_log.async_add_delta_content_stream( + "", mock_delta_stream() + ): + pass + + return await orig_converse(**kwargs) + + with patch("homeassistant.components.conversation.async_converse", mock_converse): + await client.send_json_auto_id( + { + "type": "assist_pipeline/run", + "start_stage": "intent", + "end_stage": "intent", + "input": {"text": "Are the lights on?"}, + "conversation_id": "mock-conversation-id", + "device_id": "mock-device-id", + } + ) + + # result + msg = await client.receive_json() + assert msg["success"] + + events = [] + for _ in range(6): + msg = await client.receive_json() + if msg["event"]["type"] == "intent-progress": + events.append(msg["event"]["data"]) + + assert events == expected_delta_events diff --git a/tests/components/conversation/test_chat_log.py b/tests/components/conversation/test_chat_log.py index 0c11d19aab2..a4dc9b819c1 100644 --- a/tests/components/conversation/test_chat_log.py +++ b/tests/components/conversation/test_chat_log.py @@ -1,6 +1,7 @@ """Test the conversation session.""" from collections.abc import Generator +from dataclasses import asdict from datetime import timedelta from unittest.mock import AsyncMock, Mock, patch @@ -524,18 +525,29 @@ async def test_add_delta_content_stream( return tool_input.tool_args["param1"] mock_tool.async_call.side_effect = tool_call + expected_delta = [] async def stream(): """Yield deltas.""" for d in deltas: yield d + expected_delta.append(d) + + captured_deltas = [] with ( patch( "homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[] ) as mock_get_tools, chat_session.async_get_chat_session(hass) as session, - async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + async_get_chat_log( + hass, + session, + mock_conversation_input, + chat_log_delta_listener=lambda chat_log, delta: captured_deltas.append( + delta + ), + ) as chat_log, ): mock_get_tools.return_value = [mock_tool] await chat_log.async_update_llm_data( @@ -545,13 +557,17 @@ async def test_add_delta_content_stream( user_llm_prompt=None, ) - results = [ - tool_result_content - async for tool_result_content in chat_log.async_add_delta_content_stream( - "mock-agent-id", stream() - ) - ] + results = [] + async for content in chat_log.async_add_delta_content_stream( + "mock-agent-id", stream() + ): + results.append(content) + # Interweave the tool results with the source deltas into expected_delta + if content.role == "tool_result": + expected_delta.append(asdict(content)) + + assert captured_deltas == expected_delta assert results == snapshot assert chat_log.content[2:] == results