From e1383d30e7a1bde65a23f1392e5ffa9ebc955cac Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 25 Oct 2025 22:36:31 -0400 Subject: [PATCH] Add websocket command to interact with chat logs --- .../components/conversation/chat_log.py | 137 +++++++++++++- .../components/conversation/const.py | 11 +- homeassistant/components/conversation/http.py | 27 +++ .../components/homeassistant/const.py | 2 +- .../components/conversation/test_chat_log.py | 174 +++++++++++++++++- tests/components/conversation/test_http.py | 91 ++++++++- 6 files changed, 434 insertions(+), 8 deletions(-) diff --git a/homeassistant/components/conversation/chat_log.py b/homeassistant/components/conversation/chat_log.py index 736bf128e60..96cfa62f4c3 100644 --- a/homeassistant/components/conversation/chat_log.py +++ b/homeassistant/components/conversation/chat_log.py @@ -20,10 +20,13 @@ from homeassistant.util.hass_dict import HassKey from homeassistant.util.json import JsonObjectType from . import trace +from .const import ChatLogEventType from .models import ConversationInput, ConversationResult DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs") - +SUBSCRIPTIONS: HassKey[list[Callable[[ChatLogEventType, dict[str, Any]], None]]] = ( + HassKey("conversation_chat_log_subscriptions") +) LOGGER = logging.getLogger(__name__) current_chat_log: ContextVar[ChatLog | None] = ContextVar( @@ -31,6 +34,37 @@ current_chat_log: ContextVar[ChatLog | None] = ContextVar( ) +@callback +def async_subscribe_chat_logs( + hass: HomeAssistant, + callback_func: Callable[[ChatLogEventType, dict[str, Any]], None], +) -> Callable[[], None]: + """Subscribe to all chat logs.""" + subscriptions = hass.data.get(SUBSCRIPTIONS) + if subscriptions is None: + subscriptions = [] + hass.data[SUBSCRIPTIONS] = subscriptions + + subscriptions.append(callback_func) + + @callback + def unsubscribe() -> None: + """Unsubscribe from chat logs.""" + subscriptions.remove(callback_func) + + return unsubscribe + + +@callback +def _async_notify_subscribers( + hass: HomeAssistant, event_type: ChatLogEventType, data: dict[str, Any] +) -> None: + """Notify subscribers of a chat log event.""" + if subscriptions := hass.data.get(SUBSCRIPTIONS): + for callback_func in subscriptions: + callback_func(event_type, data) + + @contextmanager def async_get_chat_log( hass: HomeAssistant, @@ -63,6 +97,8 @@ def async_get_chat_log( all_chat_logs = {} hass.data[DATA_CHAT_LOGS] = all_chat_logs + is_new_log = session.conversation_id not in all_chat_logs + if chat_log := all_chat_logs.get(session.conversation_id): chat_log = replace(chat_log, content=chat_log.content.copy()) else: @@ -71,6 +107,12 @@ def async_get_chat_log( if chat_log_delta_listener: chat_log.delta_listener = chat_log_delta_listener + # Fire CREATED event for new chat logs before any content is added + if is_new_log: + _async_notify_subscribers( + hass, ChatLogEventType.CREATED, {"chat_log": chat_log.as_dict()} + ) + if user_input is not None: chat_log.async_add_user_content(UserContent(content=user_input.text)) @@ -84,14 +126,26 @@ def async_get_chat_log( LOGGER.debug( "Chat Log opened but no assistant message was added, ignoring update" ) + # If this was a new log but nothing was added, fire DELETED to clean up + if is_new_log: + _async_notify_subscribers( + hass, + ChatLogEventType.DELETED, + {"conversation_id": session.conversation_id}, + ) return - if session.conversation_id not in all_chat_logs: + if is_new_log: @callback def do_cleanup() -> None: """Handle cleanup.""" all_chat_logs.pop(session.conversation_id) + _async_notify_subscribers( + hass, + ChatLogEventType.DELETED, + {"conversation_id": session.conversation_id}, + ) session.async_on_cleanup(do_cleanup) @@ -100,6 +154,13 @@ def async_get_chat_log( all_chat_logs[session.conversation_id] = chat_log + # For new logs, CREATED was already fired before content was added + # For existing logs, fire UPDATED + if not is_new_log: + _async_notify_subscribers( + hass, ChatLogEventType.UPDATED, {"chat_log": chat_log.as_dict()} + ) + class ConverseError(HomeAssistantError): """Error during initialization of conversation. @@ -130,6 +191,10 @@ class SystemContent: role: Literal["system"] = field(init=False, default="system") content: str + def as_dict(self) -> dict[str, Any]: + """Return a dictionary representation of the content.""" + return {"role": self.role, "content": self.content} + @dataclass(frozen=True) class UserContent: @@ -139,6 +204,15 @@ class UserContent: content: str attachments: list[Attachment] | None = field(default=None) + def as_dict(self) -> dict[str, Any]: + """Return a dictionary representation of the content.""" + result: dict[str, Any] = {"role": self.role, "content": self.content} + if self.attachments: + result["attachments"] = [ + attachment.as_dict() for attachment in self.attachments + ] + return result + @dataclass(frozen=True) class Attachment: @@ -153,6 +227,14 @@ class Attachment: path: Path """Path to the attachment on disk.""" + def as_dict(self) -> dict[str, Any]: + """Return a dictionary representation of the attachment.""" + return { + "media_content_id": self.media_content_id, + "mime_type": self.mime_type, + "path": str(self.path), + } + @dataclass(frozen=True) class AssistantContent: @@ -165,6 +247,17 @@ class AssistantContent: tool_calls: list[llm.ToolInput] | None = None native: Any = None + def as_dict(self) -> dict[str, Any]: + """Return a dictionary representation of the content.""" + result: dict[str, Any] = {"role": self.role, "agent_id": self.agent_id} + if self.content: + result["content"] = self.content + if self.thinking_content: + result["thinking_content"] = self.thinking_content + if self.tool_calls: + result["tool_calls"] = self.tool_calls + return result + @dataclass(frozen=True) class ToolResultContent: @@ -176,6 +269,16 @@ class ToolResultContent: tool_name: str tool_result: JsonObjectType + def as_dict(self) -> dict[str, Any]: + """Return a dictionary representation of the content.""" + return { + "role": self.role, + "agent_id": self.agent_id, + "tool_call_id": self.tool_call_id, + "tool_name": self.tool_name, + "tool_result": self.tool_result, + } + type Content = SystemContent | UserContent | AssistantContent | ToolResultContent @@ -211,6 +314,13 @@ class ChatLog: delta_listener: Callable[[ChatLog, dict], None] | None = None llm_input_provided_index = 0 + def as_dict(self) -> dict[str, Any]: + """Return a dictionary representation of the chat log.""" + return { + "conversation_id": self.conversation_id, + "continue_conversation": self.continue_conversation, + } + @property def continue_conversation(self) -> bool: """Return whether the conversation should continue.""" @@ -241,6 +351,11 @@ class ChatLog: """Add user content to the log.""" LOGGER.debug("Adding user content: %s", content) self.content.append(content) + _async_notify_subscribers( + self.hass, + ChatLogEventType.CONTENT_ADDED, + {"conversation_id": self.conversation_id, "content": content.as_dict()}, + ) @callback def async_add_assistant_content_without_tools( @@ -259,6 +374,11 @@ class ChatLog: ): raise ValueError("Non-external tool calls not allowed") self.content.append(content) + _async_notify_subscribers( + self.hass, + ChatLogEventType.CONTENT_ADDED, + {"conversation_id": self.conversation_id, "content": content.as_dict()}, + ) async def async_add_assistant_content( self, @@ -317,6 +437,14 @@ class ChatLog: tool_result=tool_result, ) self.content.append(response_content) + _async_notify_subscribers( + self.hass, + ChatLogEventType.CONTENT_ADDED, + { + "conversation_id": self.conversation_id, + "content": response_content.as_dict(), + }, + ) yield response_content async def async_add_delta_content_stream( @@ -593,6 +721,11 @@ class ChatLog: self.llm_api = llm_api self.extra_system_prompt = extra_system_prompt self.content[0] = SystemContent(content=prompt) + _async_notify_subscribers( + self.hass, + ChatLogEventType.UPDATED, + {"conversation_id": self.conversation_id, "chat_log": self.as_dict()}, + ) LOGGER.debug("Prompt: %s", self.content) LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None) diff --git a/homeassistant/components/conversation/const.py b/homeassistant/components/conversation/const.py index e1029de9918..53e99b34ecc 100644 --- a/homeassistant/components/conversation/const.py +++ b/homeassistant/components/conversation/const.py @@ -2,7 +2,7 @@ from __future__ import annotations -from enum import IntFlag +from enum import IntFlag, StrEnum from typing import TYPE_CHECKING from homeassistant.util.hass_dict import HassKey @@ -30,3 +30,12 @@ class ConversationEntityFeature(IntFlag): """Supported features of the conversation entity.""" CONTROL = 1 + + +class ChatLogEventType(StrEnum): + """Chat log event type.""" + + CREATED = "created" + UPDATED = "updated" + DELETED = "deleted" + CONTENT_ADDED = "content_added" diff --git a/homeassistant/components/conversation/http.py b/homeassistant/components/conversation/http.py index 9d3eb35a7e3..be823908e6b 100644 --- a/homeassistant/components/conversation/http.py +++ b/homeassistant/components/conversation/http.py @@ -20,6 +20,7 @@ from .agent_manager import ( async_get_agent, get_agent_manager, ) +from .chat_log import async_subscribe_chat_logs from .const import DATA_COMPONENT from .entity import ConversationEntity from .models import ConversationInput @@ -35,6 +36,7 @@ def async_setup(hass: HomeAssistant) -> None: websocket_api.async_register_command(hass, websocket_list_sentences) websocket_api.async_register_command(hass, websocket_hass_agent_debug) websocket_api.async_register_command(hass, websocket_hass_agent_language_scores) + websocket_api.async_register_command(hass, websocket_subscribe_chat_logs) @websocket_api.websocket_command( @@ -265,3 +267,28 @@ class ConversationProcessView(http.HomeAssistantView): ) return self.json(result.as_dict()) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "conversation/chat_log/subscribe", + } +) +@websocket_api.require_admin +def websocket_subscribe_chat_logs( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Subscribe to all chat logs.""" + + @callback + def forward_events(event_type: str, data: dict) -> None: + """Forward chat log events to websocket connection.""" + connection.send_message( + {"type": "event", "event_type": event_type, "data": data} + ) + + unsubscribe = async_subscribe_chat_logs(hass, forward_events) + connection.subscriptions[msg["id"]] = unsubscribe + connection.send_result(msg["id"]) diff --git a/homeassistant/components/homeassistant/const.py b/homeassistant/components/homeassistant/const.py index 7fad6728a74..3ca8a14cce7 100644 --- a/homeassistant/components/homeassistant/const.py +++ b/homeassistant/components/homeassistant/const.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: DOMAIN = ha.DOMAIN -DATA_EXPOSED_ENTITIES: HassKey[ExposedEntities] = HassKey(f"{DOMAIN}.exposed_entites") +DATA_EXPOSED_ENTITIES: HassKey[ExposedEntities] = HassKey(f"{DOMAIN}.exposed_entities") DATA_STOP_HANDLER = f"{DOMAIN}.stop_handler" SERVICE_HOMEASSISTANT_STOP: Final = "stop" diff --git a/tests/components/conversation/test_chat_log.py b/tests/components/conversation/test_chat_log.py index 3fc13e93508..7114eb07d9b 100644 --- a/tests/components/conversation/test_chat_log.py +++ b/tests/components/conversation/test_chat_log.py @@ -2,6 +2,8 @@ from dataclasses import asdict from datetime import timedelta +from pathlib import Path +from typing import Any from unittest.mock import AsyncMock, Mock, patch import pytest @@ -16,7 +18,12 @@ from homeassistant.components.conversation import ( UserContent, async_get_chat_log, ) -from homeassistant.components.conversation.chat_log import DATA_CHAT_LOGS +from homeassistant.components.conversation.chat_log import ( + DATA_CHAT_LOGS, + Attachment, + ChatLogEventType, + async_subscribe_chat_logs, +) from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import chat_session, llm @@ -841,3 +848,168 @@ async def test_chat_log_continue_conversation( ) ) assert chat_log.continue_conversation is True + + +async def test_chat_log_subscription( + hass: HomeAssistant, + mock_conversation_input: ConversationInput, +) -> None: + """Test comprehensive chat log subscription functionality.""" + + # Track all events received + received_events = [] + + def event_callback(event_type: ChatLogEventType, data: dict[str, Any]) -> None: + """Track received events.""" + received_events.append((event_type, data)) + + # Subscribe to chat log events + unsubscribe = async_subscribe_chat_logs(hass, event_callback) + + with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input) as chat_log, + ): + conversation_id = session.conversation_id + + # Test adding different types of content and verify events are sent + chat_log.async_add_user_content( + UserContent( + content="Check this image", + attachments=[ + Attachment( + mime_type="image/jpeg", + media_content_id="media-source://bla", + path=Path("test_image.jpg"), + ) + ], + ) + ) + # Check user content with attachments event + assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED + user_event = received_events[-1][1]["content"] + assert user_event["content"] == "Check this image" + assert len(user_event["attachments"]) == 1 + assert user_event["attachments"][0]["mime_type"] == "image/jpeg" + + chat_log.async_add_assistant_content_without_tools( + AssistantContent( + agent_id="test-agent", content="Hello! How can I help you?" + ) + ) + # Check basic assistant content event + assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED + basic_event = received_events[-1][1]["content"] + assert basic_event["content"] == "Hello! How can I help you?" + assert basic_event["agent_id"] == "test-agent" + + chat_log.async_add_assistant_content_without_tools( + AssistantContent( + agent_id="test-agent", + content="Let me think about that...", + thinking_content="I need to analyze the user's request carefully.", + ) + ) + # Check assistant content with thinking event + assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED + thinking_event = received_events[-1][1]["content"] + assert ( + thinking_event["thinking_content"] + == "I need to analyze the user's request carefully." + ) + + chat_log.async_add_assistant_content_without_tools( + AssistantContent( + agent_id="test-agent", + content="Here's some data:", + native={"type": "chart", "data": [1, 2, 3, 4, 5]}, + ) + ) + # Check assistant content with native event + assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED + native_event = received_events[-1][1]["content"] + assert native_event["content"] == "Here's some data:" + assert native_event["agent_id"] == "test-agent" + + chat_log.async_add_assistant_content_without_tools( + ToolResultContent( + agent_id="test-agent", + tool_call_id="test-tool-call-123", + tool_name="test_tool", + tool_result="Tool execution completed successfully", + ) + ) + # Check tool result content event + assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED + tool_result_event = received_events[-1][1]["content"] + assert tool_result_event["tool_name"] == "test_tool" + assert ( + tool_result_event["tool_result"] == "Tool execution completed successfully" + ) + + chat_log.async_add_assistant_content_without_tools( + AssistantContent( + agent_id="test-agent", + content="I'll call an external service", + tool_calls=[ + llm.ToolInput( + id="external-tool-call-123", + tool_name="external_api_call", + tool_args={"endpoint": "https://api.example.com/data"}, + external=True, + ) + ], + ) + ) + # Check external tool call event + assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED + external_tool_event = received_events[-1][1]["content"] + assert len(external_tool_event["tool_calls"]) == 1 + assert external_tool_event["tool_calls"][0].tool_name == "external_api_call" + + # Verify we received the expected events + # Should have: 1 CREATED event + 7 CONTENT_ADDED events + assert len(received_events) == 8 + + # Check the first event is CREATED + assert received_events[0][0] == ChatLogEventType.CREATED + assert received_events[0][1]["chat_log"]["conversation_id"] == conversation_id + + # Check the second event is CONTENT_ADDED (from mock_conversation_input) + assert received_events[1][0] == ChatLogEventType.CONTENT_ADDED + assert received_events[1][1]["conversation_id"] == conversation_id + + # Test cleanup functionality + assert conversation_id in hass.data[chat_session.DATA_CHAT_SESSION] + + # Set the last updated to be older than the timeout + hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = ( + dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + ) + + async_fire_time_changed( + hass, + dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1), + ) + + # Check that DELETED event was sent + assert received_events[-1][0] == ChatLogEventType.DELETED + assert received_events[-1][1]["conversation_id"] == conversation_id + + # Test that unsubscribing stops receiving events + events_before_unsubscribe = len(received_events) + unsubscribe() + + # Create a new session and add content - should not receive events + with ( + chat_session.async_get_chat_session(hass) as session2, + async_get_chat_log(hass, session2, mock_conversation_input) as chat_log2, + ): + chat_log2.async_add_assistant_content_without_tools( + AssistantContent( + agent_id="test-agent", content="This should not be received" + ) + ) + + # Verify no new events were received after unsubscribing + assert len(received_events) == events_before_unsubscribe diff --git a/tests/components/conversation/test_http.py b/tests/components/conversation/test_http.py index 24fc4d1b135..2efeb395570 100644 --- a/tests/components/conversation/test_http.py +++ b/tests/components/conversation/test_http.py @@ -7,17 +7,26 @@ from unittest.mock import patch import pytest from syrupy.assertion import SnapshotAssertion -from homeassistant.components.conversation import async_get_agent +from homeassistant.components.conversation import ( + ConversationInput, + async_get_agent, + async_get_chat_log, +) from homeassistant.components.conversation.const import HOME_ASSISTANT_AGENT from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN from homeassistant.const import ATTR_FRIENDLY_NAME from homeassistant.core import HomeAssistant -from homeassistant.helpers import area_registry as ar, entity_registry as er, intent +from homeassistant.helpers import ( + area_registry as ar, + chat_session, + entity_registry as er, + intent, +) from homeassistant.setup import async_setup_component from . import MockAgent -from tests.common import async_mock_service +from tests.common import MockUser, async_mock_service from tests.typing import ClientSessionGenerator, WebSocketGenerator AGENT_ID_OPTIONS = [ @@ -590,3 +599,79 @@ async def test_ws_hass_language_scores_with_filter( # GB English should be preferred result = msg["result"] assert result["preferred_language"] == "en-GB" + + +async def test_ws_chat_log_subscription( + hass: HomeAssistant, + init_components, + mock_conversation_input: ConversationInput, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test that we can subscribe to chat logs.""" + client = await hass_ws_client(hass) + + await client.send_json_auto_id({"type": "conversation/chat_log/subscribe"}) + msg = await client.receive_json() + assert msg["success"] + + with ( + chat_session.async_get_chat_session(hass) as session, + async_get_chat_log(hass, session, mock_conversation_input), + ): + conversation_id = session.conversation_id + + # We should receive 3 events: + # 1. The CREATED event (fired before content is added) + msg = await client.receive_json() + assert msg == { + "type": "event", + "event_type": "created", + "data": { + "chat_log": { + "conversation_id": conversation_id, + "continue_conversation": False, + } + }, + } + + # 2. The user input content added event + msg = await client.receive_json() + assert msg == { + "type": "event", + "event_type": "content_added", + "data": { + "conversation_id": conversation_id, + "content": { + "content": "Hello", + "role": "user", + }, + }, + } + + # 3. The DELETED event (since no assistant message was added) + msg = await client.receive_json() + assert msg == { + "type": "event", + "event_type": "deleted", + "data": { + "conversation_id": conversation_id, + }, + } + + +async def test_ws_chat_log_subscription_requires_admin( + hass: HomeAssistant, + init_components, + hass_ws_client: WebSocketGenerator, + hass_admin_user: MockUser, +) -> None: + """Test that chat log subscription requires admin access.""" + # Create a non-admin user + hass_admin_user.groups = [] + client = await hass_ws_client(hass) + + await client.send_json_auto_id({"type": "conversation/chat_log/subscribe"}) + msg = await client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == "unauthorized"