mirror of
https://github.com/home-assistant/core.git
synced 2025-07-14 00:37:13 +00:00
Add pipeline intent-progress events based on deltas (#138095)
Add intent progress Assist event
This commit is contained in:
parent
fa3acde684
commit
29c6a2ec13
@ -374,6 +374,7 @@ class PipelineEventType(StrEnum):
|
|||||||
STT_VAD_END = "stt-vad-end"
|
STT_VAD_END = "stt-vad-end"
|
||||||
STT_END = "stt-end"
|
STT_END = "stt-end"
|
||||||
INTENT_START = "intent-start"
|
INTENT_START = "intent-start"
|
||||||
|
INTENT_PROGRESS = "intent-progress"
|
||||||
INTENT_END = "intent-end"
|
INTENT_END = "intent-end"
|
||||||
TTS_START = "tts-start"
|
TTS_START = "tts-start"
|
||||||
TTS_END = "tts-end"
|
TTS_END = "tts-end"
|
||||||
@ -1093,6 +1094,20 @@ class PipelineRun:
|
|||||||
agent_id = conversation.HOME_ASSISTANT_AGENT
|
agent_id = conversation.HOME_ASSISTANT_AGENT
|
||||||
processed_locally = True
|
processed_locally = True
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def chat_log_delta_listener(
|
||||||
|
chat_log: conversation.ChatLog, delta: dict
|
||||||
|
) -> None:
|
||||||
|
"""Handle chat log delta."""
|
||||||
|
self.process_event(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.INTENT_PROGRESS,
|
||||||
|
{
|
||||||
|
"chat_log_delta": delta,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
chat_session.async_get_chat_session(
|
chat_session.async_get_chat_session(
|
||||||
self.hass, user_input.conversation_id
|
self.hass, user_input.conversation_id
|
||||||
@ -1101,6 +1116,7 @@ class PipelineRun:
|
|||||||
self.hass,
|
self.hass,
|
||||||
session,
|
session,
|
||||||
user_input,
|
user_input,
|
||||||
|
chat_log_delta_listener=chat_log_delta_listener,
|
||||||
) as chat_log,
|
) as chat_log,
|
||||||
):
|
):
|
||||||
# It was already handled, create response and add to chat history
|
# It was already handled, create response and add to chat history
|
||||||
|
@ -3,10 +3,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import AsyncGenerator, AsyncIterable, Generator
|
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import asdict, dataclass, field, replace
|
||||||
import logging
|
import logging
|
||||||
from typing import Literal, TypedDict
|
from typing import Literal, TypedDict
|
||||||
|
|
||||||
@ -36,6 +36,8 @@ def async_get_chat_log(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
session: chat_session.ChatSession,
|
session: chat_session.ChatSession,
|
||||||
user_input: ConversationInput | None = None,
|
user_input: ConversationInput | None = None,
|
||||||
|
*,
|
||||||
|
chat_log_delta_listener: Callable[[ChatLog, dict], None] | None = None,
|
||||||
) -> Generator[ChatLog]:
|
) -> Generator[ChatLog]:
|
||||||
"""Return chat log for a specific chat session."""
|
"""Return chat log for a specific chat session."""
|
||||||
# If a chat log is already active and it's the requested conversation ID,
|
# If a chat log is already active and it's the requested conversation ID,
|
||||||
@ -43,6 +45,10 @@ def async_get_chat_log(
|
|||||||
if (
|
if (
|
||||||
chat_log := current_chat_log.get()
|
chat_log := current_chat_log.get()
|
||||||
) and chat_log.conversation_id == session.conversation_id:
|
) and chat_log.conversation_id == session.conversation_id:
|
||||||
|
if chat_log_delta_listener is not None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot attach chat log delta listener unless initial caller"
|
||||||
|
)
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
chat_log.async_add_user_content(UserContent(content=user_input.text))
|
chat_log.async_add_user_content(UserContent(content=user_input.text))
|
||||||
|
|
||||||
@ -59,6 +65,9 @@ def async_get_chat_log(
|
|||||||
else:
|
else:
|
||||||
chat_log = ChatLog(hass, session.conversation_id)
|
chat_log = ChatLog(hass, session.conversation_id)
|
||||||
|
|
||||||
|
if chat_log_delta_listener:
|
||||||
|
chat_log.delta_listener = chat_log_delta_listener
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
chat_log.async_add_user_content(UserContent(content=user_input.text))
|
chat_log.async_add_user_content(UserContent(content=user_input.text))
|
||||||
|
|
||||||
@ -83,6 +92,9 @@ def async_get_chat_log(
|
|||||||
|
|
||||||
session.async_on_cleanup(do_cleanup)
|
session.async_on_cleanup(do_cleanup)
|
||||||
|
|
||||||
|
if chat_log_delta_listener:
|
||||||
|
chat_log.delta_listener = None
|
||||||
|
|
||||||
all_chat_logs[session.conversation_id] = chat_log
|
all_chat_logs[session.conversation_id] = chat_log
|
||||||
|
|
||||||
|
|
||||||
@ -165,6 +177,7 @@ class ChatLog:
|
|||||||
content: list[Content] = field(default_factory=lambda: [SystemContent(content="")])
|
content: list[Content] = field(default_factory=lambda: [SystemContent(content="")])
|
||||||
extra_system_prompt: str | None = None
|
extra_system_prompt: str | None = None
|
||||||
llm_api: llm.APIInstance | None = None
|
llm_api: llm.APIInstance | None = None
|
||||||
|
delta_listener: Callable[[ChatLog, dict], None] | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unresponded_tool_results(self) -> bool:
|
def unresponded_tool_results(self) -> bool:
|
||||||
@ -275,6 +288,8 @@ class ChatLog:
|
|||||||
self.llm_api.async_call_tool(tool_call),
|
self.llm_api.async_call_tool(tool_call),
|
||||||
name=f"llm_tool_{tool_call.id}",
|
name=f"llm_tool_{tool_call.id}",
|
||||||
)
|
)
|
||||||
|
if self.delta_listener:
|
||||||
|
self.delta_listener(self, delta) # type: ignore[arg-type]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Starting a new message
|
# Starting a new message
|
||||||
@ -294,10 +309,15 @@ class ChatLog:
|
|||||||
content, tool_call_tasks=tool_call_tasks
|
content, tool_call_tasks=tool_call_tasks
|
||||||
):
|
):
|
||||||
yield tool_result
|
yield tool_result
|
||||||
|
if self.delta_listener:
|
||||||
|
self.delta_listener(self, asdict(tool_result))
|
||||||
|
|
||||||
current_content = delta.get("content") or ""
|
current_content = delta.get("content") or ""
|
||||||
current_tool_calls = delta.get("tool_calls") or []
|
current_tool_calls = delta.get("tool_calls") or []
|
||||||
|
|
||||||
|
if self.delta_listener:
|
||||||
|
self.delta_listener(self, delta) # type: ignore[arg-type]
|
||||||
|
|
||||||
if current_content or current_tool_calls:
|
if current_content or current_tool_calls:
|
||||||
content = AssistantContent(
|
content = AssistantContent(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
@ -309,6 +329,8 @@ class ChatLog:
|
|||||||
content, tool_call_tasks=tool_call_tasks
|
content, tool_call_tasks=tool_call_tasks
|
||||||
):
|
):
|
||||||
yield tool_result
|
yield tool_result
|
||||||
|
if self.delta_listener:
|
||||||
|
self.delta_listener(self, asdict(tool_result))
|
||||||
|
|
||||||
async def async_update_llm_data(
|
async def async_update_llm_data(
|
||||||
self,
|
self,
|
||||||
|
@ -9,6 +9,7 @@ from unittest.mock import ANY, Mock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.assist_pipeline.const import (
|
from homeassistant.components.assist_pipeline.const import (
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
SAMPLE_CHANNELS,
|
SAMPLE_CHANNELS,
|
||||||
@ -22,7 +23,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
|||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import device_registry as dr
|
from homeassistant.helpers import chat_session, device_registry as dr
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
BYTES_ONE_SECOND,
|
BYTES_ONE_SECOND,
|
||||||
@ -2727,3 +2728,62 @@ async def test_stt_cooldown_different_ids(
|
|||||||
|
|
||||||
# Both should start stt
|
# Both should start stt
|
||||||
assert {event_type_1, event_type_2} == {"stt-start"}
|
assert {event_type_1, event_type_2} == {"stt-start"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_intent_progress_event(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
) -> None:
|
||||||
|
"""Test intent-progress events from a pipeline are forwarded."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
orig_converse = conversation.async_converse
|
||||||
|
expected_delta_events = [
|
||||||
|
{"chat_log_delta": {"role": "assistant"}},
|
||||||
|
{"chat_log_delta": {"content": "Hello"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
async def mock_delta_stream():
|
||||||
|
"""Mock delta stream."""
|
||||||
|
for d in expected_delta_events:
|
||||||
|
yield d["chat_log_delta"]
|
||||||
|
|
||||||
|
async def mock_converse(**kwargs):
|
||||||
|
"""Mock converse method."""
|
||||||
|
with (
|
||||||
|
chat_session.async_get_chat_session(
|
||||||
|
kwargs["hass"], kwargs["conversation_id"]
|
||||||
|
) as session,
|
||||||
|
conversation.async_get_chat_log(hass, session) as chat_log,
|
||||||
|
):
|
||||||
|
async for _content in chat_log.async_add_delta_content_stream(
|
||||||
|
"", mock_delta_stream()
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return await orig_converse(**kwargs)
|
||||||
|
|
||||||
|
with patch("homeassistant.components.conversation.async_converse", mock_converse):
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/run",
|
||||||
|
"start_stage": "intent",
|
||||||
|
"end_stage": "intent",
|
||||||
|
"input": {"text": "Are the lights on?"},
|
||||||
|
"conversation_id": "mock-conversation-id",
|
||||||
|
"device_id": "mock-device-id",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
events = []
|
||||||
|
for _ in range(6):
|
||||||
|
msg = await client.receive_json()
|
||||||
|
if msg["event"]["type"] == "intent-progress":
|
||||||
|
events.append(msg["event"]["data"])
|
||||||
|
|
||||||
|
assert events == expected_delta_events
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Test the conversation session."""
|
"""Test the conversation session."""
|
||||||
|
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
from dataclasses import asdict
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
@ -524,18 +525,29 @@ async def test_add_delta_content_stream(
|
|||||||
return tool_input.tool_args["param1"]
|
return tool_input.tool_args["param1"]
|
||||||
|
|
||||||
mock_tool.async_call.side_effect = tool_call
|
mock_tool.async_call.side_effect = tool_call
|
||||||
|
expected_delta = []
|
||||||
|
|
||||||
async def stream():
|
async def stream():
|
||||||
"""Yield deltas."""
|
"""Yield deltas."""
|
||||||
for d in deltas:
|
for d in deltas:
|
||||||
yield d
|
yield d
|
||||||
|
expected_delta.append(d)
|
||||||
|
|
||||||
|
captured_deltas = []
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
|
"homeassistant.helpers.llm.AssistAPI._async_get_tools", return_value=[]
|
||||||
) as mock_get_tools,
|
) as mock_get_tools,
|
||||||
chat_session.async_get_chat_session(hass) as session,
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
async_get_chat_log(
|
||||||
|
hass,
|
||||||
|
session,
|
||||||
|
mock_conversation_input,
|
||||||
|
chat_log_delta_listener=lambda chat_log, delta: captured_deltas.append(
|
||||||
|
delta
|
||||||
|
),
|
||||||
|
) as chat_log,
|
||||||
):
|
):
|
||||||
mock_get_tools.return_value = [mock_tool]
|
mock_get_tools.return_value = [mock_tool]
|
||||||
await chat_log.async_update_llm_data(
|
await chat_log.async_update_llm_data(
|
||||||
@ -545,13 +557,17 @@ async def test_add_delta_content_stream(
|
|||||||
user_llm_prompt=None,
|
user_llm_prompt=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
results = [
|
results = []
|
||||||
tool_result_content
|
async for content in chat_log.async_add_delta_content_stream(
|
||||||
async for tool_result_content in chat_log.async_add_delta_content_stream(
|
"mock-agent-id", stream()
|
||||||
"mock-agent-id", stream()
|
):
|
||||||
)
|
results.append(content)
|
||||||
]
|
|
||||||
|
|
||||||
|
# Interweave the tool results with the source deltas into expected_delta
|
||||||
|
if content.role == "tool_result":
|
||||||
|
expected_delta.append(asdict(content))
|
||||||
|
|
||||||
|
assert captured_deltas == expected_delta
|
||||||
assert results == snapshot
|
assert results == snapshot
|
||||||
assert chat_log.content[2:] == results
|
assert chat_log.content[2:] == results
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user