Add websocket command to interact with chat logs

This commit is contained in:
Paulus Schoutsen
2025-10-25 22:36:31 -04:00
parent 4c9810a10e
commit e1383d30e7
6 changed files with 434 additions and 8 deletions

View File

@@ -20,10 +20,13 @@ from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType from homeassistant.util.json import JsonObjectType
from . import trace from . import trace
from .const import ChatLogEventType
from .models import ConversationInput, ConversationResult from .models import ConversationInput, ConversationResult
DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs") DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
SUBSCRIPTIONS: HassKey[list[Callable[[ChatLogEventType, dict[str, Any]], None]]] = (
HassKey("conversation_chat_log_subscriptions")
)
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
current_chat_log: ContextVar[ChatLog | None] = ContextVar( current_chat_log: ContextVar[ChatLog | None] = ContextVar(
@@ -31,6 +34,37 @@ current_chat_log: ContextVar[ChatLog | None] = ContextVar(
) )
@callback
def async_subscribe_chat_logs(
hass: HomeAssistant,
callback_func: Callable[[ChatLogEventType, dict[str, Any]], None],
) -> Callable[[], None]:
"""Subscribe to all chat logs."""
subscriptions = hass.data.get(SUBSCRIPTIONS)
if subscriptions is None:
subscriptions = []
hass.data[SUBSCRIPTIONS] = subscriptions
subscriptions.append(callback_func)
@callback
def unsubscribe() -> None:
"""Unsubscribe from chat logs."""
subscriptions.remove(callback_func)
return unsubscribe
@callback
def _async_notify_subscribers(
hass: HomeAssistant, event_type: ChatLogEventType, data: dict[str, Any]
) -> None:
"""Notify subscribers of a chat log event."""
if subscriptions := hass.data.get(SUBSCRIPTIONS):
for callback_func in subscriptions:
callback_func(event_type, data)
@contextmanager @contextmanager
def async_get_chat_log( def async_get_chat_log(
hass: HomeAssistant, hass: HomeAssistant,
@@ -63,6 +97,8 @@ def async_get_chat_log(
all_chat_logs = {} all_chat_logs = {}
hass.data[DATA_CHAT_LOGS] = all_chat_logs hass.data[DATA_CHAT_LOGS] = all_chat_logs
is_new_log = session.conversation_id not in all_chat_logs
if chat_log := all_chat_logs.get(session.conversation_id): if chat_log := all_chat_logs.get(session.conversation_id):
chat_log = replace(chat_log, content=chat_log.content.copy()) chat_log = replace(chat_log, content=chat_log.content.copy())
else: else:
@@ -71,6 +107,12 @@ def async_get_chat_log(
if chat_log_delta_listener: if chat_log_delta_listener:
chat_log.delta_listener = chat_log_delta_listener chat_log.delta_listener = chat_log_delta_listener
# Fire CREATED event for new chat logs before any content is added
if is_new_log:
_async_notify_subscribers(
hass, ChatLogEventType.CREATED, {"chat_log": chat_log.as_dict()}
)
if user_input is not None: if user_input is not None:
chat_log.async_add_user_content(UserContent(content=user_input.text)) chat_log.async_add_user_content(UserContent(content=user_input.text))
@@ -84,14 +126,26 @@ def async_get_chat_log(
LOGGER.debug( LOGGER.debug(
"Chat Log opened but no assistant message was added, ignoring update" "Chat Log opened but no assistant message was added, ignoring update"
) )
# If this was a new log but nothing was added, fire DELETED to clean up
if is_new_log:
_async_notify_subscribers(
hass,
ChatLogEventType.DELETED,
{"conversation_id": session.conversation_id},
)
return return
if session.conversation_id not in all_chat_logs: if is_new_log:
@callback @callback
def do_cleanup() -> None: def do_cleanup() -> None:
"""Handle cleanup.""" """Handle cleanup."""
all_chat_logs.pop(session.conversation_id) all_chat_logs.pop(session.conversation_id)
_async_notify_subscribers(
hass,
ChatLogEventType.DELETED,
{"conversation_id": session.conversation_id},
)
session.async_on_cleanup(do_cleanup) session.async_on_cleanup(do_cleanup)
@@ -100,6 +154,13 @@ def async_get_chat_log(
all_chat_logs[session.conversation_id] = chat_log all_chat_logs[session.conversation_id] = chat_log
# For new logs, CREATED was already fired before content was added
# For existing logs, fire UPDATED
if not is_new_log:
_async_notify_subscribers(
hass, ChatLogEventType.UPDATED, {"chat_log": chat_log.as_dict()}
)
class ConverseError(HomeAssistantError): class ConverseError(HomeAssistantError):
"""Error during initialization of conversation. """Error during initialization of conversation.
@@ -130,6 +191,10 @@ class SystemContent:
role: Literal["system"] = field(init=False, default="system") role: Literal["system"] = field(init=False, default="system")
content: str content: str
def as_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the content."""
return {"role": self.role, "content": self.content}
@dataclass(frozen=True) @dataclass(frozen=True)
class UserContent: class UserContent:
@@ -139,6 +204,15 @@ class UserContent:
content: str content: str
attachments: list[Attachment] | None = field(default=None) attachments: list[Attachment] | None = field(default=None)
def as_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the content."""
result: dict[str, Any] = {"role": self.role, "content": self.content}
if self.attachments:
result["attachments"] = [
attachment.as_dict() for attachment in self.attachments
]
return result
@dataclass(frozen=True) @dataclass(frozen=True)
class Attachment: class Attachment:
@@ -153,6 +227,14 @@ class Attachment:
path: Path path: Path
"""Path to the attachment on disk.""" """Path to the attachment on disk."""
def as_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the attachment."""
return {
"media_content_id": self.media_content_id,
"mime_type": self.mime_type,
"path": str(self.path),
}
@dataclass(frozen=True) @dataclass(frozen=True)
class AssistantContent: class AssistantContent:
@@ -165,6 +247,17 @@ class AssistantContent:
tool_calls: list[llm.ToolInput] | None = None tool_calls: list[llm.ToolInput] | None = None
native: Any = None native: Any = None
def as_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the content."""
result: dict[str, Any] = {"role": self.role, "agent_id": self.agent_id}
if self.content:
result["content"] = self.content
if self.thinking_content:
result["thinking_content"] = self.thinking_content
if self.tool_calls:
result["tool_calls"] = self.tool_calls
return result
@dataclass(frozen=True) @dataclass(frozen=True)
class ToolResultContent: class ToolResultContent:
@@ -176,6 +269,16 @@ class ToolResultContent:
tool_name: str tool_name: str
tool_result: JsonObjectType tool_result: JsonObjectType
def as_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the content."""
return {
"role": self.role,
"agent_id": self.agent_id,
"tool_call_id": self.tool_call_id,
"tool_name": self.tool_name,
"tool_result": self.tool_result,
}
type Content = SystemContent | UserContent | AssistantContent | ToolResultContent type Content = SystemContent | UserContent | AssistantContent | ToolResultContent
@@ -211,6 +314,13 @@ class ChatLog:
delta_listener: Callable[[ChatLog, dict], None] | None = None delta_listener: Callable[[ChatLog, dict], None] | None = None
llm_input_provided_index = 0 llm_input_provided_index = 0
def as_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of the chat log."""
return {
"conversation_id": self.conversation_id,
"continue_conversation": self.continue_conversation,
}
@property @property
def continue_conversation(self) -> bool: def continue_conversation(self) -> bool:
"""Return whether the conversation should continue.""" """Return whether the conversation should continue."""
@@ -241,6 +351,11 @@ class ChatLog:
"""Add user content to the log.""" """Add user content to the log."""
LOGGER.debug("Adding user content: %s", content) LOGGER.debug("Adding user content: %s", content)
self.content.append(content) self.content.append(content)
_async_notify_subscribers(
self.hass,
ChatLogEventType.CONTENT_ADDED,
{"conversation_id": self.conversation_id, "content": content.as_dict()},
)
@callback @callback
def async_add_assistant_content_without_tools( def async_add_assistant_content_without_tools(
@@ -259,6 +374,11 @@ class ChatLog:
): ):
raise ValueError("Non-external tool calls not allowed") raise ValueError("Non-external tool calls not allowed")
self.content.append(content) self.content.append(content)
_async_notify_subscribers(
self.hass,
ChatLogEventType.CONTENT_ADDED,
{"conversation_id": self.conversation_id, "content": content.as_dict()},
)
async def async_add_assistant_content( async def async_add_assistant_content(
self, self,
@@ -317,6 +437,14 @@ class ChatLog:
tool_result=tool_result, tool_result=tool_result,
) )
self.content.append(response_content) self.content.append(response_content)
_async_notify_subscribers(
self.hass,
ChatLogEventType.CONTENT_ADDED,
{
"conversation_id": self.conversation_id,
"content": response_content.as_dict(),
},
)
yield response_content yield response_content
async def async_add_delta_content_stream( async def async_add_delta_content_stream(
@@ -593,6 +721,11 @@ class ChatLog:
self.llm_api = llm_api self.llm_api = llm_api
self.extra_system_prompt = extra_system_prompt self.extra_system_prompt = extra_system_prompt
self.content[0] = SystemContent(content=prompt) self.content[0] = SystemContent(content=prompt)
_async_notify_subscribers(
self.hass,
ChatLogEventType.UPDATED,
{"conversation_id": self.conversation_id, "chat_log": self.as_dict()},
)
LOGGER.debug("Prompt: %s", self.content) LOGGER.debug("Prompt: %s", self.content)
LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None) LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)

View File

@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from enum import IntFlag from enum import IntFlag, StrEnum
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from homeassistant.util.hass_dict import HassKey from homeassistant.util.hass_dict import HassKey
@@ -30,3 +30,12 @@ class ConversationEntityFeature(IntFlag):
"""Supported features of the conversation entity.""" """Supported features of the conversation entity."""
CONTROL = 1 CONTROL = 1
class ChatLogEventType(StrEnum):
"""Chat log event type."""
CREATED = "created"
UPDATED = "updated"
DELETED = "deleted"
CONTENT_ADDED = "content_added"

View File

@@ -20,6 +20,7 @@ from .agent_manager import (
async_get_agent, async_get_agent,
get_agent_manager, get_agent_manager,
) )
from .chat_log import async_subscribe_chat_logs
from .const import DATA_COMPONENT from .const import DATA_COMPONENT
from .entity import ConversationEntity from .entity import ConversationEntity
from .models import ConversationInput from .models import ConversationInput
@@ -35,6 +36,7 @@ def async_setup(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, websocket_list_sentences) websocket_api.async_register_command(hass, websocket_list_sentences)
websocket_api.async_register_command(hass, websocket_hass_agent_debug) websocket_api.async_register_command(hass, websocket_hass_agent_debug)
websocket_api.async_register_command(hass, websocket_hass_agent_language_scores) websocket_api.async_register_command(hass, websocket_hass_agent_language_scores)
websocket_api.async_register_command(hass, websocket_subscribe_chat_logs)
@websocket_api.websocket_command( @websocket_api.websocket_command(
@@ -265,3 +267,28 @@ class ConversationProcessView(http.HomeAssistantView):
) )
return self.json(result.as_dict()) return self.json(result.as_dict())
@websocket_api.websocket_command(
{
vol.Required("type"): "conversation/chat_log/subscribe",
}
)
@websocket_api.require_admin
def websocket_subscribe_chat_logs(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Subscribe to all chat logs."""
@callback
def forward_events(event_type: str, data: dict) -> None:
"""Forward chat log events to websocket connection."""
connection.send_message(
{"type": "event", "event_type": event_type, "data": data}
)
unsubscribe = async_subscribe_chat_logs(hass, forward_events)
connection.subscriptions[msg["id"]] = unsubscribe
connection.send_result(msg["id"])

View File

@@ -12,7 +12,7 @@ if TYPE_CHECKING:
DOMAIN = ha.DOMAIN DOMAIN = ha.DOMAIN
DATA_EXPOSED_ENTITIES: HassKey[ExposedEntities] = HassKey(f"{DOMAIN}.exposed_entites") DATA_EXPOSED_ENTITIES: HassKey[ExposedEntities] = HassKey(f"{DOMAIN}.exposed_entities")
DATA_STOP_HANDLER = f"{DOMAIN}.stop_handler" DATA_STOP_HANDLER = f"{DOMAIN}.stop_handler"
SERVICE_HOMEASSISTANT_STOP: Final = "stop" SERVICE_HOMEASSISTANT_STOP: Final = "stop"

View File

@@ -2,6 +2,8 @@
from dataclasses import asdict from dataclasses import asdict
from datetime import timedelta from datetime import timedelta
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
@@ -16,7 +18,12 @@ from homeassistant.components.conversation import (
UserContent, UserContent,
async_get_chat_log, async_get_chat_log,
) )
from homeassistant.components.conversation.chat_log import DATA_CHAT_LOGS from homeassistant.components.conversation.chat_log import (
DATA_CHAT_LOGS,
Attachment,
ChatLogEventType,
async_subscribe_chat_logs,
)
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, llm from homeassistant.helpers import chat_session, llm
@@ -841,3 +848,168 @@ async def test_chat_log_continue_conversation(
) )
) )
assert chat_log.continue_conversation is True assert chat_log.continue_conversation is True
async def test_chat_log_subscription(
hass: HomeAssistant,
mock_conversation_input: ConversationInput,
) -> None:
"""Test comprehensive chat log subscription functionality."""
# Track all events received
received_events = []
def event_callback(event_type: ChatLogEventType, data: dict[str, Any]) -> None:
"""Track received events."""
received_events.append((event_type, data))
# Subscribe to chat log events
unsubscribe = async_subscribe_chat_logs(hass, event_callback)
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
):
conversation_id = session.conversation_id
# Test adding different types of content and verify events are sent
chat_log.async_add_user_content(
UserContent(
content="Check this image",
attachments=[
Attachment(
mime_type="image/jpeg",
media_content_id="media-source://bla",
path=Path("test_image.jpg"),
)
],
)
)
# Check user content with attachments event
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
user_event = received_events[-1][1]["content"]
assert user_event["content"] == "Check this image"
assert len(user_event["attachments"]) == 1
assert user_event["attachments"][0]["mime_type"] == "image/jpeg"
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id="test-agent", content="Hello! How can I help you?"
)
)
# Check basic assistant content event
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
basic_event = received_events[-1][1]["content"]
assert basic_event["content"] == "Hello! How can I help you?"
assert basic_event["agent_id"] == "test-agent"
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id="test-agent",
content="Let me think about that...",
thinking_content="I need to analyze the user's request carefully.",
)
)
# Check assistant content with thinking event
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
thinking_event = received_events[-1][1]["content"]
assert (
thinking_event["thinking_content"]
== "I need to analyze the user's request carefully."
)
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id="test-agent",
content="Here's some data:",
native={"type": "chart", "data": [1, 2, 3, 4, 5]},
)
)
# Check assistant content with native event
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
native_event = received_events[-1][1]["content"]
assert native_event["content"] == "Here's some data:"
assert native_event["agent_id"] == "test-agent"
chat_log.async_add_assistant_content_without_tools(
ToolResultContent(
agent_id="test-agent",
tool_call_id="test-tool-call-123",
tool_name="test_tool",
tool_result="Tool execution completed successfully",
)
)
# Check tool result content event
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
tool_result_event = received_events[-1][1]["content"]
assert tool_result_event["tool_name"] == "test_tool"
assert (
tool_result_event["tool_result"] == "Tool execution completed successfully"
)
chat_log.async_add_assistant_content_without_tools(
AssistantContent(
agent_id="test-agent",
content="I'll call an external service",
tool_calls=[
llm.ToolInput(
id="external-tool-call-123",
tool_name="external_api_call",
tool_args={"endpoint": "https://api.example.com/data"},
external=True,
)
],
)
)
# Check external tool call event
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
external_tool_event = received_events[-1][1]["content"]
assert len(external_tool_event["tool_calls"]) == 1
assert external_tool_event["tool_calls"][0].tool_name == "external_api_call"
# Verify we received the expected events
# Should have: 1 CREATED event + 7 CONTENT_ADDED events
assert len(received_events) == 8
# Check the first event is CREATED
assert received_events[0][0] == ChatLogEventType.CREATED
assert received_events[0][1]["chat_log"]["conversation_id"] == conversation_id
# Check the second event is CONTENT_ADDED (from mock_conversation_input)
assert received_events[1][0] == ChatLogEventType.CONTENT_ADDED
assert received_events[1][1]["conversation_id"] == conversation_id
# Test cleanup functionality
assert conversation_id in hass.data[chat_session.DATA_CHAT_SESSION]
# Set the last updated to be older than the timeout
hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT
)
async_fire_time_changed(
hass,
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
)
# Check that DELETED event was sent
assert received_events[-1][0] == ChatLogEventType.DELETED
assert received_events[-1][1]["conversation_id"] == conversation_id
# Test that unsubscribing stops receiving events
events_before_unsubscribe = len(received_events)
unsubscribe()
# Create a new session and add content - should not receive events
with (
chat_session.async_get_chat_session(hass) as session2,
async_get_chat_log(hass, session2, mock_conversation_input) as chat_log2,
):
chat_log2.async_add_assistant_content_without_tools(
AssistantContent(
agent_id="test-agent", content="This should not be received"
)
)
# Verify no new events were received after unsubscribing
assert len(received_events) == events_before_unsubscribe

View File

@@ -7,17 +7,26 @@ from unittest.mock import patch
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components.conversation import async_get_agent from homeassistant.components.conversation import (
ConversationInput,
async_get_agent,
async_get_chat_log,
)
from homeassistant.components.conversation.const import HOME_ASSISTANT_AGENT from homeassistant.components.conversation.const import HOME_ASSISTANT_AGENT
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.const import ATTR_FRIENDLY_NAME from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import area_registry as ar, entity_registry as er, intent from homeassistant.helpers import (
area_registry as ar,
chat_session,
entity_registry as er,
intent,
)
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import MockAgent from . import MockAgent
from tests.common import async_mock_service from tests.common import MockUser, async_mock_service
from tests.typing import ClientSessionGenerator, WebSocketGenerator from tests.typing import ClientSessionGenerator, WebSocketGenerator
AGENT_ID_OPTIONS = [ AGENT_ID_OPTIONS = [
@@ -590,3 +599,79 @@ async def test_ws_hass_language_scores_with_filter(
# GB English should be preferred # GB English should be preferred
result = msg["result"] result = msg["result"]
assert result["preferred_language"] == "en-GB" assert result["preferred_language"] == "en-GB"
async def test_ws_chat_log_subscription(
hass: HomeAssistant,
init_components,
mock_conversation_input: ConversationInput,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test that we can subscribe to chat logs."""
client = await hass_ws_client(hass)
await client.send_json_auto_id({"type": "conversation/chat_log/subscribe"})
msg = await client.receive_json()
assert msg["success"]
with (
chat_session.async_get_chat_session(hass) as session,
async_get_chat_log(hass, session, mock_conversation_input),
):
conversation_id = session.conversation_id
# We should receive 3 events:
# 1. The CREATED event (fired before content is added)
msg = await client.receive_json()
assert msg == {
"type": "event",
"event_type": "created",
"data": {
"chat_log": {
"conversation_id": conversation_id,
"continue_conversation": False,
}
},
}
# 2. The user input content added event
msg = await client.receive_json()
assert msg == {
"type": "event",
"event_type": "content_added",
"data": {
"conversation_id": conversation_id,
"content": {
"content": "Hello",
"role": "user",
},
},
}
# 3. The DELETED event (since no assistant message was added)
msg = await client.receive_json()
assert msg == {
"type": "event",
"event_type": "deleted",
"data": {
"conversation_id": conversation_id,
},
}
async def test_ws_chat_log_subscription_requires_admin(
hass: HomeAssistant,
init_components,
hass_ws_client: WebSocketGenerator,
hass_admin_user: MockUser,
) -> None:
"""Test that chat log subscription requires admin access."""
# Create a non-admin user
hass_admin_user.groups = []
client = await hass_ws_client(hass)
await client.send_json_auto_id({"type": "conversation/chat_log/subscribe"})
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"]["code"] == "unauthorized"