Add websocket command to interact with chat logs

This commit is contained in:
Paulus Schoutsen
2025-10-25 22:36:31 -04:00
parent 4c9810a10e
commit e1383d30e7
6 changed files with 434 additions and 8 deletions

View File

@@ -20,10 +20,13 @@ from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType
from . import trace
from .const import ChatLogEventType
from .models import ConversationInput, ConversationResult
DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
SUBSCRIPTIONS: HassKey[list[Callable[[ChatLogEventType, dict[str, Any]], None]]] = (
HassKey("conversation_chat_log_subscriptions")
)
LOGGER = logging.getLogger(__name__)
current_chat_log: ContextVar[ChatLog | None] = ContextVar(
@@ -31,6 +34,37 @@ current_chat_log: ContextVar[ChatLog | None] = ContextVar(
)
@callback
def async_subscribe_chat_logs(
hass: HomeAssistant,
callback_func: Callable[[ChatLogEventType, dict[str, Any]], None],
) -> Callable[[], None]:
"""Subscribe to all chat logs."""
subscriptions = hass.data.get(SUBSCRIPTIONS)
if subscriptions is None:
subscriptions = []
hass.data[SUBSCRIPTIONS] = subscriptions
subscriptions.append(callback_func)
@callback
def unsubscribe() -> None:
"""Unsubscribe from chat logs."""
subscriptions.remove(callback_func)
return unsubscribe
@callback
def _async_notify_subscribers(
hass: HomeAssistant, event_type: ChatLogEventType, data: dict[str, Any]
) -> None:
"""Notify subscribers of a chat log event."""
if subscriptions := hass.data.get(SUBSCRIPTIONS):
for callback_func in subscriptions:
callback_func(event_type, data)
@contextmanager
def async_get_chat_log(
hass: HomeAssistant,
@@ -63,6 +97,8 @@ def async_get_chat_log(
all_chat_logs = {}
hass.data[DATA_CHAT_LOGS] = all_chat_logs
is_new_log = session.conversation_id not in all_chat_logs
if chat_log := all_chat_logs.get(session.conversation_id):
chat_log = replace(chat_log, content=chat_log.content.copy())
else:
@@ -71,6 +107,12 @@ def async_get_chat_log(
if chat_log_delta_listener:
chat_log.delta_listener = chat_log_delta_listener
# Fire CREATED event for new chat logs before any content is added
if is_new_log:
_async_notify_subscribers(
hass, ChatLogEventType.CREATED, {"chat_log": chat_log.as_dict()}
)
if user_input is not None:
chat_log.async_add_user_content(UserContent(content=user_input.text))
@@ -84,14 +126,26 @@ def async_get_chat_log(
LOGGER.debug(
"Chat Log opened but no assistant message was added, ignoring update"
)
# If this was a new log but nothing was added, fire DELETED to clean up
if is_new_log:
_async_notify_subscribers(
hass,
ChatLogEventType.DELETED,
{"conversation_id": session.conversation_id},
)
return
if session.conversation_id not in all_chat_logs:
if is_new_log:
@callback
def do_cleanup() -> None:
"""Handle cleanup."""
all_chat_logs.pop(session.conversation_id)
_async_notify_subscribers(
hass,
ChatLogEventType.DELETED,
{"conversation_id": session.conversation_id},
)
session.async_on_cleanup(do_cleanup)
@@ -100,6 +154,13 @@ def async_get_chat_log(
all_chat_logs[session.conversation_id] = chat_log
# For new logs, CREATED was already fired before content was added
# For existing logs, fire UPDATED
if not is_new_log:
_async_notify_subscribers(
hass, ChatLogEventType.UPDATED, {"chat_log": chat_log.as_dict()}
)
class ConverseError(HomeAssistantError):
"""Error during initialization of conversation.
@@ -130,6 +191,10 @@ class SystemContent:
role: Literal["system"] = field(init=False, default="system")
content: str
def as_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the content."""
return {"role": self.role, "content": self.content}
@dataclass(frozen=True)
class UserContent:
@@ -139,6 +204,15 @@ class UserContent:
content: str
attachments: list[Attachment] | None = field(default=None)
def as_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the content."""
result: dict[str, Any] = {"role": self.role, "content": self.content}
if self.attachments:
result["attachments"] = [
attachment.as_dict() for attachment in self.attachments
]
return result
@dataclass(frozen=True)
class Attachment:
@@ -153,6 +227,14 @@ class Attachment:
path: Path
"""Path to the attachment on disk."""
def as_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the attachment."""
return {
"media_content_id": self.media_content_id,
"mime_type": self.mime_type,
"path": str(self.path),
}
@dataclass(frozen=True)
class AssistantContent:
@@ -165,6 +247,17 @@ class AssistantContent:
tool_calls: list[llm.ToolInput] | None = None
native: Any = None
def as_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the content."""
result: dict[str, Any] = {"role": self.role, "agent_id": self.agent_id}
if self.content:
result["content"] = self.content
if self.thinking_content:
result["thinking_content"] = self.thinking_content
if self.tool_calls:
result["tool_calls"] = self.tool_calls
return result
@dataclass(frozen=True)
class ToolResultContent:
@@ -176,6 +269,16 @@ class ToolResultContent:
tool_name: str
tool_result: JsonObjectType
def as_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the content."""
return {
"role": self.role,
"agent_id": self.agent_id,
"tool_call_id": self.tool_call_id,
"tool_name": self.tool_name,
"tool_result": self.tool_result,
}
type Content = SystemContent | UserContent | AssistantContent | ToolResultContent
@@ -211,6 +314,13 @@ class ChatLog:
delta_listener: Callable[[ChatLog, dict], None] | None = None
llm_input_provided_index = 0
def as_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the chat log."""
return {
"conversation_id": self.conversation_id,
"continue_conversation": self.continue_conversation,
}
@property
def continue_conversation(self) -> bool:
"""Return whether the conversation should continue."""
@@ -241,6 +351,11 @@ class ChatLog:
"""Add user content to the log."""
LOGGER.debug("Adding user content: %s", content)
self.content.append(content)
_async_notify_subscribers(
self.hass,
ChatLogEventType.CONTENT_ADDED,
{"conversation_id": self.conversation_id, "content": content.as_dict()},
)
@callback
def async_add_assistant_content_without_tools(
@@ -259,6 +374,11 @@ class ChatLog:
):
raise ValueError("Non-external tool calls not allowed")
self.content.append(content)
_async_notify_subscribers(
self.hass,
ChatLogEventType.CONTENT_ADDED,
{"conversation_id": self.conversation_id, "content": content.as_dict()},
)
async def async_add_assistant_content(
self,
@@ -317,6 +437,14 @@ class ChatLog:
tool_result=tool_result,
)
self.content.append(response_content)
_async_notify_subscribers(
self.hass,
ChatLogEventType.CONTENT_ADDED,
{
"conversation_id": self.conversation_id,
"content": response_content.as_dict(),
},
)
yield response_content
async def async_add_delta_content_stream(
@@ -593,6 +721,11 @@ class ChatLog:
self.llm_api = llm_api
self.extra_system_prompt = extra_system_prompt
self.content[0] = SystemContent(content=prompt)
_async_notify_subscribers(
self.hass,
ChatLogEventType.UPDATED,
{"conversation_id": self.conversation_id, "chat_log": self.as_dict()},
)
LOGGER.debug("Prompt: %s", self.content)
LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from enum import IntFlag
from enum import IntFlag, StrEnum
from typing import TYPE_CHECKING
from homeassistant.util.hass_dict import HassKey
@@ -30,3 +30,12 @@ class ConversationEntityFeature(IntFlag):
"""Supported features of the conversation entity."""
CONTROL = 1
class ChatLogEventType(StrEnum):
"""Chat log event type."""
CREATED = "created"
UPDATED = "updated"
DELETED = "deleted"
CONTENT_ADDED = "content_added"

View File

@@ -20,6 +20,7 @@ from .agent_manager import (
async_get_agent,
get_agent_manager,
)
from .chat_log import async_subscribe_chat_logs
from .const import DATA_COMPONENT
from .entity import ConversationEntity
from .models import ConversationInput
@@ -35,6 +36,7 @@ def async_setup(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, websocket_list_sentences)
websocket_api.async_register_command(hass, websocket_hass_agent_debug)
websocket_api.async_register_command(hass, websocket_hass_agent_language_scores)
websocket_api.async_register_command(hass, websocket_subscribe_chat_logs)
@websocket_api.websocket_command(
@@ -265,3 +267,28 @@ class ConversationProcessView(http.HomeAssistantView):
)
return self.json(result.as_dict())
@websocket_api.websocket_command(
{
vol.Required("type"): "conversation/chat_log/subscribe",
}
)
@websocket_api.require_admin
def websocket_subscribe_chat_logs(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Subscribe to all chat logs."""
@callback
def forward_events(event_type: str, data: dict) -> None:
"""Forward chat log events to websocket connection."""
connection.send_message(
{"type": "event", "event_type": event_type, "data": data}
)
unsubscribe = async_subscribe_chat_logs(hass, forward_events)
connection.subscriptions[msg["id"]] = unsubscribe
connection.send_result(msg["id"])

View File

@@ -12,7 +12,7 @@ if TYPE_CHECKING:
DOMAIN = ha.DOMAIN
DATA_EXPOSED_ENTITIES: HassKey[ExposedEntities] = HassKey(f"{DOMAIN}.exposed_entites")
DATA_EXPOSED_ENTITIES: HassKey[ExposedEntities] = HassKey(f"{DOMAIN}.exposed_entities")
DATA_STOP_HANDLER = f"{DOMAIN}.stop_handler"
SERVICE_HOMEASSISTANT_STOP: Final = "stop"

View File

@@ -2,6 +2,8 @@
from dataclasses import asdict
from datetime import timedelta
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
import pytest
@@ -16,7 +18,12 @@ from homeassistant.components.conversation import (
UserContent,
async_get_chat_log,
)
from homeassistant.components.conversation.chat_log import DATA_CHAT_LOGS
from homeassistant.components.conversation.chat_log import (
DATA_CHAT_LOGS,
Attachment,
ChatLogEventType,
async_subscribe_chat_logs,
)
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, llm
@@ -841,3 +848,168 @@ async def test_chat_log_continue_conversation(
)
)
assert chat_log.continue_conversation is True
async def test_chat_log_subscription(
hass: HomeAssistant,
mock_conversation_input: ConversationInput,
) -> None:
"""Test comprehensive chat log subscription functionality."""
# Track all events received
received_events = []
def event_callback(event_type: ChatLogEventType, data: dict[str, Any]) -> None:
"""Track received events."""
received_events.append((event_type, data))
# Subscribe to chat log events
unsubscribe = async_subscribe_chat_logs(hass, event_callback)
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
conversation_id = session.conversation_id
# Test adding different types of content and verify events are sent
chat_log.async_add_user_content(
UserContent(
content="Check this image",
attachments=[
Attachment(
mime_type="image/jpeg",
media_content_id="media-source://bla",
path=Path("test_image.jpg"),
)
],
)
)
# Check user content with attachments event
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
user_event = received_events[-1][1]["content"]
assert user_event["content"] == "Check this image"
assert len(user_event["attachments"]) == 1
assert user_event["attachments"][0]["mime_type"] == "image/jpeg"
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id="test-agent", content="Hello! How can I help you?"
)
)
# Check basic assistant content event
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
basic_event = received_events[-1][1]["content"]
assert basic_event["content"] == "Hello! How can I help you?"
assert basic_event["agent_id"] == "test-agent"
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id="test-agent",
content="Let me think about that...",
thinking_content="I need to analyze the user's request carefully.",
)
)
# Check assistant content with thinking event
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
thinking_event = received_events[-1][1]["content"]
assert (
thinking_event["thinking_content"]
== "I need to analyze the user's request carefully."
)
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id="test-agent",
content="Here's some data:",
native={"type": "chart", "data": [1, 2, 3, 4, 5]},
)
)
# Check assistant content with native event
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
native_event = received_events[-1][1]["content"]
assert native_event["content"] == "Here's some data:"
assert native_event["agent_id"] == "test-agent"
chat_log.async_add_assistant_content_without_tools(
ToolResultContent(
agent_id="test-agent",
tool_call_id="test-tool-call-123",
tool_name="test_tool",
tool_result="Tool execution completed successfully",
)
)
# Check tool result content event
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
tool_result_event = received_events[-1][1]["content"]
assert tool_result_event["tool_name"] == "test_tool"
assert (
tool_result_event["tool_result"] == "Tool execution completed successfully"
)
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id="test-agent",
content="I'll call an external service",
tool_calls=[
llm.ToolInput(
id="external-tool-call-123",
tool_name="external_api_call",
tool_args={"endpoint": "https://api.example.com/data"},
external=True,
)
],
)
)
# Check external tool call event
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
external_tool_event = received_events[-1][1]["content"]
assert len(external_tool_event["tool_calls"]) == 1
assert external_tool_event["tool_calls"][0].tool_name == "external_api_call"
# Verify we received the expected events
# Should have: 1 CREATED event + 7 CONTENT_ADDED events
assert len(received_events) == 8
# Check the first event is CREATED
assert received_events[0][0] == ChatLogEventType.CREATED
assert received_events[0][1]["chat_log"]["conversation_id"] == conversation_id
# Check the second event is CONTENT_ADDED (from mock_conversation_input)
assert received_events[1][0] == ChatLogEventType.CONTENT_ADDED
assert received_events[1][1]["conversation_id"] == conversation_id
# Test cleanup functionality
assert conversation_id in hass.data[chat_session.DATA_CHAT_SESSION]
# Set the last updated to be older than the timeout
hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT
)
async_fire_time_changed(
hass,
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
)
# Check that DELETED event was sent
assert received_events[-1][0] == ChatLogEventType.DELETED
assert received_events[-1][1]["conversation_id"] == conversation_id
# Test that unsubscribing stops receiving events
events_before_unsubscribe = len(received_events)
unsubscribe()
# Create a new session and add content - should not receive events
with (
chat_session.async_get_chat_session(hass) as session2,
async_get_chat_log(hass, session2, mock_conversation_input) as chat_log2,
):
chat_log2.async_add_assistant_content_without_tools(
AssistantContent(
agent_id="test-agent", content="This should not be received"
)
)
# Verify no new events were received after unsubscribing
assert len(received_events) == events_before_unsubscribe

View File

@@ -7,17 +7,26 @@ from unittest.mock import patch
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.conversation import async_get_agent
from homeassistant.components.conversation import (
ConversationInput,
async_get_agent,
async_get_chat_log,
)
from homeassistant.components.conversation.const import HOME_ASSISTANT_AGENT
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import HomeAssistant
from homeassistant.helpers import area_registry as ar, entity_registry as er, intent
from homeassistant.helpers import (
area_registry as ar,
chat_session,
entity_registry as er,
intent,
)
from homeassistant.setup import async_setup_component
from . import MockAgent
from tests.common import async_mock_service
from tests.common import MockUser, async_mock_service
from tests.typing import ClientSessionGenerator, WebSocketGenerator
AGENT_ID_OPTIONS = [
@@ -590,3 +599,79 @@ async def test_ws_hass_language_scores_with_filter(
# GB English should be preferred
result = msg["result"]
assert result["preferred_language"] == "en-GB"
async def test_ws_chat_log_subscription(
hass: HomeAssistant,
init_components,
mock_conversation_input: ConversationInput,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test that we can subscribe to chat logs."""
client = await hass_ws_client(hass)
await client.send_json_auto_id({"type": "conversation/chat_log/subscribe"})
msg = await client.receive_json()
assert msg["success"]
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input),
):
conversation_id = session.conversation_id
# We should receive 3 events:
# 1. The CREATED event (fired before content is added)
msg = await client.receive_json()
assert msg == {
"type": "event",
"event_type": "created",
"data": {
"chat_log": {
"conversation_id": conversation_id,
"continue_conversation": False,
}
},
}
# 2. The user input content added event
msg = await client.receive_json()
assert msg == {
"type": "event",
"event_type": "content_added",
"data": {
"conversation_id": conversation_id,
"content": {
"content": "Hello",
"role": "user",
},
},
}
# 3. The DELETED event (since no assistant message was added)
msg = await client.receive_json()
assert msg == {
"type": "event",
"event_type": "deleted",
"data": {
"conversation_id": conversation_id,
},
}
async def test_ws_chat_log_subscription_requires_admin(
hass: HomeAssistant,
init_components,
hass_ws_client: WebSocketGenerator,
hass_admin_user: MockUser,
) -> None:
"""Test that chat log subscription requires admin access."""
# Create a non-admin user
hass_admin_user.groups = []
client = await hass_ws_client(hass)
await client.send_json_auto_id({"type": "conversation/chat_log/subscribe"})
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == "unauthorized"