Add pipeline intent-progress events based on deltas (#138095)

Add intent progress Assist event
This commit is contained in:
Paulus Schoutsen 2025-02-09 21:09:52 -05:00 committed by GitHub
parent fa3acde684
commit 29c6a2ec13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 124 additions and 10 deletions

View File

@ -374,6 +374,7 @@ class PipelineEventType(StrEnum):
STT_VAD_END = "stt-vad-end"
STT_END = "stt-end"
INTENT_START = "intent-start"
INTENT_PROGRESS = "intent-progress"
INTENT_END = "intent-end"
TTS_START = "tts-start"
TTS_END = "tts-end"
@ -1093,6 +1094,20 @@ class PipelineRun:
agent_id = conversation.HOME_ASSISTANT_AGENT
processed_locally = True
@callback
def chat_log_delta_listener(
    chat_log: conversation.ChatLog, delta: dict
) -> None:
    """Forward a chat log delta to listeners as an intent-progress event."""
    # Wrap the raw delta so pipeline consumers can tell it apart from
    # other intent-progress payload fields added in the future.
    event_data = {"chat_log_delta": delta}
    event = PipelineEvent(PipelineEventType.INTENT_PROGRESS, event_data)
    self.process_event(event)
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
@ -1101,6 +1116,7 @@ class PipelineRun:
self.hass,
session,
user_input,
chat_log_delta_listener=chat_log_delta_listener,
) as chat_log,
):
# It was already handled, create response and add to chat history

View File

@ -3,10 +3,10 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator, AsyncIterable, Generator
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field, replace
from dataclasses import asdict, dataclass, field, replace
import logging
from typing import Literal, TypedDict
@ -36,6 +36,8 @@ def async_get_chat_log(
hass: HomeAssistant,
session: chat_session.ChatSession,
user_input: ConversationInput | None = None,
*,
chat_log_delta_listener: Callable[[ChatLog, dict], None] | None = None,
) -> Generator[ChatLog]:
"""Return chat log for a specific chat session."""
# If a chat log is already active and it's the requested conversation ID,
@ -43,6 +45,10 @@ def async_get_chat_log(
if (
chat_log := current_chat_log.get()
) and chat_log.conversation_id == session.conversation_id:
if chat_log_delta_listener is not None:
raise RuntimeError(
"Cannot attach chat log delta listener unless initial caller"
)
if user_input is not None:
chat_log.async_add_user_content(UserContent(content=user_input.text))
@ -59,6 +65,9 @@ def async_get_chat_log(
else:
chat_log = ChatLog(hass, session.conversation_id)
if chat_log_delta_listener:
chat_log.delta_listener = chat_log_delta_listener
if user_input is not None:
chat_log.async_add_user_content(UserContent(content=user_input.text))
@ -83,6 +92,9 @@ def async_get_chat_log(
session.async_on_cleanup(do_cleanup)
if chat_log_delta_listener:
chat_log.delta_listener = None
all_chat_logs[session.conversation_id] = chat_log
@ -165,6 +177,7 @@ class ChatLog:
content: list[Content] = field(default_factory=lambda: [SystemContent(content="")])
extra_system_prompt: str | None = None
llm_api: llm.APIInstance | None = None
delta_listener: Callable[[ChatLog, dict], None] | None = None
@property
def unresponded_tool_results(self) -> bool:
@ -275,6 +288,8 @@ class ChatLog:
self.llm_api.async_call_tool(tool_call),
name=f"llm_tool_{tool_call.id}",
)
if self.delta_listener:
self.delta_listener(self, delta) # type: ignore[arg-type]
continue
# Starting a new message
@ -294,10 +309,15 @@ class ChatLog:
content, tool_call_tasks=tool_call_tasks
):
yield tool_result
if self.delta_listener:
self.delta_listener(self, asdict(tool_result))
current_content = delta.get("content") or ""
current_tool_calls = delta.get("tool_calls") or []
if self.delta_listener:
self.delta_listener(self, delta) # type: ignore[arg-type]
if current_content or current_tool_calls:
content = AssistantContent(
agent_id=agent_id,
@ -309,6 +329,8 @@ class ChatLog:
content, tool_call_tasks=tool_call_tasks
):
yield tool_result
if self.delta_listener:
self.delta_listener(self, asdict(tool_result))
async def async_update_llm_data(
self,

View File

@ -9,6 +9,7 @@ from unittest.mock import ANY, Mock, patch
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation
from homeassistant.components.assist_pipeline.const import (
DOMAIN,
SAMPLE_CHANNELS,
@ -22,7 +23,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
)
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers import chat_session, device_registry as dr
from .conftest import (
BYTES_ONE_SECOND,
@ -2727,3 +2728,62 @@ async def test_stt_cooldown_different_ids(
# Both should start stt
assert {event_type_1, event_type_2} == {"stt-start"}
async def test_intent_progress_event(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
) -> None:
    """Test intent-progress events from a pipeline are forwarded."""
    client = await hass_ws_client(hass)

    # Keep a reference to the real converse implementation so the mock can
    # delegate to it after streaming the fake deltas into the chat log.
    orig_converse = conversation.async_converse

    # The payloads we expect the pipeline to emit as intent-progress events:
    # each chat log delta is wrapped in a "chat_log_delta" key.
    expected_delta_events = [
        {"chat_log_delta": {"role": "assistant"}},
        {"chat_log_delta": {"content": "Hello"}},
    ]

    async def mock_delta_stream():
        """Mock delta stream."""
        # Yield the raw deltas (unwrapped) as an LLM streaming agent would.
        for d in expected_delta_events:
            yield d["chat_log_delta"]

    async def mock_converse(**kwargs):
        """Mock converse method."""
        # Re-enter the same chat session/log the pipeline created (same
        # conversation_id), feed the fake delta stream through it so the
        # pipeline's delta listener fires, then run the real converse.
        with (
            chat_session.async_get_chat_session(
                kwargs["hass"], kwargs["conversation_id"]
            ) as session,
            conversation.async_get_chat_log(hass, session) as chat_log,
        ):
            async for _content in chat_log.async_add_delta_content_stream(
                "", mock_delta_stream()
            ):
                pass
        return await orig_converse(**kwargs)

    with patch("homeassistant.components.conversation.async_converse", mock_converse):
        # Run an intent-only pipeline over the websocket API.
        await client.send_json_auto_id(
            {
                "type": "assist_pipeline/run",
                "start_stage": "intent",
                "end_stage": "intent",
                "input": {"text": "Are the lights on?"},
                "conversation_id": "mock-conversation-id",
                "device_id": "mock-device-id",
            }
        )

        # result
        msg = await client.receive_json()
        assert msg["success"]

        # Drain the pipeline events and keep only the intent-progress data.
        # Six events are read — presumably run-start, intent-start, the two
        # intent-progress events, intent-end and run-end; confirm against
        # the pipeline event sequence if this test is extended.
        events = []
        for _ in range(6):
            msg = await client.receive_json()
            if msg["event"]["type"] == "intent-progress":
                events.append(msg["event"]["data"])

        assert events == expected_delta_events

View File

@ -1,6 +1,7 @@
"""Test the conversation session."""
from collections.abc import Generator
from dataclasses import asdict
from datetime import timedelta
from unittest.mock import AsyncMock, Mock, patch
@ -524,18 +525,29 @@ async def test_add_delta_content_stream(
return tool_input.tool_args["param1"]
mock_tool.async_call.side_effect = tool_call
expected_delta = []
async def stream():
"""Yield deltas."""
for d in deltas:
yield d
expected_delta.append(d)
captured_deltas = []
with (
patch(
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
) as mock_get_tools,
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
async_get_chat_log(
hass,
session,
mock_conversation_input,
chat_log_delta_listener=lambda chat_log, delta: captured_deltas.append(
delta
),
) as chat_log,
):
mock_get_tools.return_value = [mock_tool]
await chat_log.async_update_llm_data(
@ -545,13 +557,17 @@ async def test_add_delta_content_stream(
user_llm_prompt=None,
)
results = [
tool_result_content
async for tool_result_content in chat_log.async_add_delta_content_stream(
"mock-agent-id", stream()
)
]
results = []
async for content in chat_log.async_add_delta_content_stream(
"mock-agent-id", stream()
):
results.append(content)
# Interweave the tool results with the source deltas into expected_delta
if content.role == "tool_result":
expected_delta.append(asdict(content))
assert captured_deltas == expected_delta
assert results == snapshot
assert chat_log.content[2:] == results