mirror of
https://github.com/home-assistant/core.git
synced 2025-07-13 16:27:08 +00:00
Add pipeline intent-progress events based on deltas (#138095)
Add intent progress Assist event
This commit is contained in:
parent
fa3acde684
commit
29c6a2ec13
@ -374,6 +374,7 @@ class PipelineEventType(StrEnum):
|
||||
STT_VAD_END = "stt-vad-end"
|
||||
STT_END = "stt-end"
|
||||
INTENT_START = "intent-start"
|
||||
INTENT_PROGRESS = "intent-progress"
|
||||
INTENT_END = "intent-end"
|
||||
TTS_START = "tts-start"
|
||||
TTS_END = "tts-end"
|
||||
@ -1093,6 +1094,20 @@ class PipelineRun:
|
||||
agent_id = conversation.HOME_ASSISTANT_AGENT
|
||||
processed_locally = True
|
||||
|
||||
@callback
|
||||
def chat_log_delta_listener(
|
||||
chat_log: conversation.ChatLog, delta: dict
|
||||
) -> None:
|
||||
"""Handle chat log delta."""
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.INTENT_PROGRESS,
|
||||
{
|
||||
"chat_log_delta": delta,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
chat_session.async_get_chat_session(
|
||||
self.hass, user_input.conversation_id
|
||||
@ -1101,6 +1116,7 @@ class PipelineRun:
|
||||
self.hass,
|
||||
session,
|
||||
user_input,
|
||||
chat_log_delta_listener=chat_log_delta_listener,
|
||||
) as chat_log,
|
||||
):
|
||||
# It was already handled, create response and add to chat history
|
||||
|
@ -3,10 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, AsyncIterable, Generator
|
||||
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field, replace
|
||||
from dataclasses import asdict, dataclass, field, replace
|
||||
import logging
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
@ -36,6 +36,8 @@ def async_get_chat_log(
|
||||
hass: HomeAssistant,
|
||||
session: chat_session.ChatSession,
|
||||
user_input: ConversationInput | None = None,
|
||||
*,
|
||||
chat_log_delta_listener: Callable[[ChatLog, dict], None] | None = None,
|
||||
) -> Generator[ChatLog]:
|
||||
"""Return chat log for a specific chat session."""
|
||||
# If a chat log is already active and it's the requested conversation ID,
|
||||
@ -43,6 +45,10 @@ def async_get_chat_log(
|
||||
if (
|
||||
chat_log := current_chat_log.get()
|
||||
) and chat_log.conversation_id == session.conversation_id:
|
||||
if chat_log_delta_listener is not None:
|
||||
raise RuntimeError(
|
||||
"Cannot attach chat log delta listener unless initial caller"
|
||||
)
|
||||
if user_input is not None:
|
||||
chat_log.async_add_user_content(UserContent(content=user_input.text))
|
||||
|
||||
@ -59,6 +65,9 @@ def async_get_chat_log(
|
||||
else:
|
||||
chat_log = ChatLog(hass, session.conversation_id)
|
||||
|
||||
if chat_log_delta_listener:
|
||||
chat_log.delta_listener = chat_log_delta_listener
|
||||
|
||||
if user_input is not None:
|
||||
chat_log.async_add_user_content(UserContent(content=user_input.text))
|
||||
|
||||
@ -83,6 +92,9 @@ def async_get_chat_log(
|
||||
|
||||
session.async_on_cleanup(do_cleanup)
|
||||
|
||||
if chat_log_delta_listener:
|
||||
chat_log.delta_listener = None
|
||||
|
||||
all_chat_logs[session.conversation_id] = chat_log
|
||||
|
||||
|
||||
@ -165,6 +177,7 @@ class ChatLog:
|
||||
content: list[Content] = field(default_factory=lambda: [SystemContent(content="")])
|
||||
extra_system_prompt: str | None = None
|
||||
llm_api: llm.APIInstance | None = None
|
||||
delta_listener: Callable[[ChatLog, dict], None] | None = None
|
||||
|
||||
@property
|
||||
def unresponded_tool_results(self) -> bool:
|
||||
@ -275,6 +288,8 @@ class ChatLog:
|
||||
self.llm_api.async_call_tool(tool_call),
|
||||
name=f"llm_tool_{tool_call.id}",
|
||||
)
|
||||
if self.delta_listener:
|
||||
self.delta_listener(self, delta) # type: ignore[arg-type]
|
||||
continue
|
||||
|
||||
# Starting a new message
|
||||
@ -294,10 +309,15 @@ class ChatLog:
|
||||
content, tool_call_tasks=tool_call_tasks
|
||||
):
|
||||
yield tool_result
|
||||
if self.delta_listener:
|
||||
self.delta_listener(self, asdict(tool_result))
|
||||
|
||||
current_content = delta.get("content") or ""
|
||||
current_tool_calls = delta.get("tool_calls") or []
|
||||
|
||||
if self.delta_listener:
|
||||
self.delta_listener(self, delta) # type: ignore[arg-type]
|
||||
|
||||
if current_content or current_tool_calls:
|
||||
content = AssistantContent(
|
||||
agent_id=agent_id,
|
||||
@ -309,6 +329,8 @@ class ChatLog:
|
||||
content, tool_call_tasks=tool_call_tasks
|
||||
):
|
||||
yield tool_result
|
||||
if self.delta_listener:
|
||||
self.delta_listener(self, asdict(tool_result))
|
||||
|
||||
async def async_update_llm_data(
|
||||
self,
|
||||
|
@ -9,6 +9,7 @@ from unittest.mock import ANY, Mock, patch
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.assist_pipeline.const import (
|
||||
DOMAIN,
|
||||
SAMPLE_CHANNELS,
|
||||
@ -22,7 +23,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.helpers import chat_session, device_registry as dr
|
||||
|
||||
from .conftest import (
|
||||
BYTES_ONE_SECOND,
|
||||
@ -2727,3 +2728,62 @@ async def test_stt_cooldown_different_ids(
|
||||
|
||||
# Both should start stt
|
||||
assert {event_type_1, event_type_2} == {"stt-start"}
|
||||
|
||||
|
||||
async def test_intent_progress_event(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
) -> None:
|
||||
"""Test intent-progress events from a pipeline are forwarded."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
orig_converse = conversation.async_converse
|
||||
expected_delta_events = [
|
||||
{"chat_log_delta": {"role": "assistant"}},
|
||||
{"chat_log_delta": {"content": "Hello"}},
|
||||
]
|
||||
|
||||
async def mock_delta_stream():
|
||||
"""Mock delta stream."""
|
||||
for d in expected_delta_events:
|
||||
yield d["chat_log_delta"]
|
||||
|
||||
async def mock_converse(**kwargs):
|
||||
"""Mock converse method."""
|
||||
with (
|
||||
chat_session.async_get_chat_session(
|
||||
kwargs["hass"], kwargs["conversation_id"]
|
||||
) as session,
|
||||
conversation.async_get_chat_log(hass, session) as chat_log,
|
||||
):
|
||||
async for _content in chat_log.async_add_delta_content_stream(
|
||||
"", mock_delta_stream()
|
||||
):
|
||||
pass
|
||||
|
||||
return await orig_converse(**kwargs)
|
||||
|
||||
with patch("homeassistant.components.conversation.async_converse", mock_converse):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "intent",
|
||||
"end_stage": "intent",
|
||||
"input": {"text": "Are the lights on?"},
|
||||
"conversation_id": "mock-conversation-id",
|
||||
"device_id": "mock-device-id",
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
events = []
|
||||
for _ in range(6):
|
||||
msg = await client.receive_json()
|
||||
if msg["event"]["type"] == "intent-progress":
|
||||
events.append(msg["event"]["data"])
|
||||
|
||||
assert events == expected_delta_events
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Test the conversation session."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from dataclasses import asdict
|
||||
from datetime import timedelta
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
@ -524,18 +525,29 @@ async def test_add_delta_content_stream(
|
||||
return tool_input.tool_args["param1"]
|
||||
|
||||
mock_tool.async_call.side_effect = tool_call
|
||||
expected_delta = []
|
||||
|
||||
async def stream():
|
||||
"""Yield deltas."""
|
||||
for d in deltas:
|
||||
yield d
|
||||
expected_delta.append(d)
|
||||
|
||||
captured_deltas = []
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
|
||||
) as mock_get_tools,
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
async_get_chat_log(
|
||||
hass,
|
||||
session,
|
||||
mock_conversation_input,
|
||||
chat_log_delta_listener=lambda chat_log, delta: captured_deltas.append(
|
||||
delta
|
||||
),
|
||||
) as chat_log,
|
||||
):
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
await chat_log.async_update_llm_data(
|
||||
@ -545,13 +557,17 @@ async def test_add_delta_content_stream(
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
|
||||
results = [
|
||||
tool_result_content
|
||||
async for tool_result_content in chat_log.async_add_delta_content_stream(
|
||||
"mock-agent-id", stream()
|
||||
)
|
||||
]
|
||||
results = []
|
||||
async for content in chat_log.async_add_delta_content_stream(
|
||||
"mock-agent-id", stream()
|
||||
):
|
||||
results.append(content)
|
||||
|
||||
# Interweave the tool results with the source deltas into expected_delta
|
||||
if content.role == "tool_result":
|
||||
expected_delta.append(asdict(content))
|
||||
|
||||
assert captured_deltas == expected_delta
|
||||
assert results == snapshot
|
||||
assert chat_log.content[2:] == results
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user