mirror of
https://github.com/home-assistant/core.git
synced 2025-10-26 12:09:32 +00:00
Add websocket command to interact with chat logs
This commit is contained in:
@@ -20,10 +20,13 @@ from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
from . import trace
|
||||
from .const import ChatLogEventType
|
||||
from .models import ConversationInput, ConversationResult
|
||||
|
||||
DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
|
||||
|
||||
SUBSCRIPTIONS: HassKey[list[Callable[[ChatLogEventType, dict[str, Any]], None]]] = (
|
||||
HassKey("conversation_chat_log_subscriptions")
|
||||
)
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
current_chat_log: ContextVar[ChatLog | None] = ContextVar(
|
||||
@@ -31,6 +34,37 @@ current_chat_log: ContextVar[ChatLog | None] = ContextVar(
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def async_subscribe_chat_logs(
|
||||
hass: HomeAssistant,
|
||||
callback_func: Callable[[ChatLogEventType, dict[str, Any]], None],
|
||||
) -> Callable[[], None]:
|
||||
"""Subscribe to all chat logs."""
|
||||
subscriptions = hass.data.get(SUBSCRIPTIONS)
|
||||
if subscriptions is None:
|
||||
subscriptions = []
|
||||
hass.data[SUBSCRIPTIONS] = subscriptions
|
||||
|
||||
subscriptions.append(callback_func)
|
||||
|
||||
@callback
|
||||
def unsubscribe() -> None:
|
||||
"""Unsubscribe from chat logs."""
|
||||
subscriptions.remove(callback_func)
|
||||
|
||||
return unsubscribe
|
||||
|
||||
|
||||
@callback
|
||||
def _async_notify_subscribers(
|
||||
hass: HomeAssistant, event_type: ChatLogEventType, data: dict[str, Any]
|
||||
) -> None:
|
||||
"""Notify subscribers of a chat log event."""
|
||||
if subscriptions := hass.data.get(SUBSCRIPTIONS):
|
||||
for callback_func in subscriptions:
|
||||
callback_func(event_type, data)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def async_get_chat_log(
|
||||
hass: HomeAssistant,
|
||||
@@ -63,6 +97,8 @@ def async_get_chat_log(
|
||||
all_chat_logs = {}
|
||||
hass.data[DATA_CHAT_LOGS] = all_chat_logs
|
||||
|
||||
is_new_log = session.conversation_id not in all_chat_logs
|
||||
|
||||
if chat_log := all_chat_logs.get(session.conversation_id):
|
||||
chat_log = replace(chat_log, content=chat_log.content.copy())
|
||||
else:
|
||||
@@ -71,6 +107,12 @@ def async_get_chat_log(
|
||||
if chat_log_delta_listener:
|
||||
chat_log.delta_listener = chat_log_delta_listener
|
||||
|
||||
# Fire CREATED event for new chat logs before any content is added
|
||||
if is_new_log:
|
||||
_async_notify_subscribers(
|
||||
hass, ChatLogEventType.CREATED, {"chat_log": chat_log.as_dict()}
|
||||
)
|
||||
|
||||
if user_input is not None:
|
||||
chat_log.async_add_user_content(UserContent(content=user_input.text))
|
||||
|
||||
@@ -84,14 +126,26 @@ def async_get_chat_log(
|
||||
LOGGER.debug(
|
||||
"Chat Log opened but no assistant message was added, ignoring update"
|
||||
)
|
||||
# If this was a new log but nothing was added, fire DELETED to clean up
|
||||
if is_new_log:
|
||||
_async_notify_subscribers(
|
||||
hass,
|
||||
ChatLogEventType.DELETED,
|
||||
{"conversation_id": session.conversation_id},
|
||||
)
|
||||
return
|
||||
|
||||
if session.conversation_id not in all_chat_logs:
|
||||
if is_new_log:
|
||||
|
||||
@callback
|
||||
def do_cleanup() -> None:
|
||||
"""Handle cleanup."""
|
||||
all_chat_logs.pop(session.conversation_id)
|
||||
_async_notify_subscribers(
|
||||
hass,
|
||||
ChatLogEventType.DELETED,
|
||||
{"conversation_id": session.conversation_id},
|
||||
)
|
||||
|
||||
session.async_on_cleanup(do_cleanup)
|
||||
|
||||
@@ -100,6 +154,13 @@ def async_get_chat_log(
|
||||
|
||||
all_chat_logs[session.conversation_id] = chat_log
|
||||
|
||||
# For new logs, CREATED was already fired before content was added
|
||||
# For existing logs, fire UPDATED
|
||||
if not is_new_log:
|
||||
_async_notify_subscribers(
|
||||
hass, ChatLogEventType.UPDATED, {"chat_log": chat_log.as_dict()}
|
||||
)
|
||||
|
||||
|
||||
class ConverseError(HomeAssistantError):
|
||||
"""Error during initialization of conversation.
|
||||
@@ -130,6 +191,10 @@ class SystemContent:
|
||||
role: Literal["system"] = field(init=False, default="system")
|
||||
content: str
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return a dictionary representation of the content."""
|
||||
return {"role": self.role, "content": self.content}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UserContent:
|
||||
@@ -139,6 +204,15 @@ class UserContent:
|
||||
content: str
|
||||
attachments: list[Attachment] | None = field(default=None)
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return a dictionary representation of the content."""
|
||||
result: dict[str, Any] = {"role": self.role, "content": self.content}
|
||||
if self.attachments:
|
||||
result["attachments"] = [
|
||||
attachment.as_dict() for attachment in self.attachments
|
||||
]
|
||||
return result
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Attachment:
|
||||
@@ -153,6 +227,14 @@ class Attachment:
|
||||
path: Path
|
||||
"""Path to the attachment on disk."""
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return a dictionary representation of the attachment."""
|
||||
return {
|
||||
"media_content_id": self.media_content_id,
|
||||
"mime_type": self.mime_type,
|
||||
"path": str(self.path),
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssistantContent:
|
||||
@@ -165,6 +247,17 @@ class AssistantContent:
|
||||
tool_calls: list[llm.ToolInput] | None = None
|
||||
native: Any = None
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return a dictionary representation of the content."""
|
||||
result: dict[str, Any] = {"role": self.role, "agent_id": self.agent_id}
|
||||
if self.content:
|
||||
result["content"] = self.content
|
||||
if self.thinking_content:
|
||||
result["thinking_content"] = self.thinking_content
|
||||
if self.tool_calls:
|
||||
result["tool_calls"] = self.tool_calls
|
||||
return result
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ToolResultContent:
|
||||
@@ -176,6 +269,16 @@ class ToolResultContent:
|
||||
tool_name: str
|
||||
tool_result: JsonObjectType
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return a dictionary representation of the content."""
|
||||
return {
|
||||
"role": self.role,
|
||||
"agent_id": self.agent_id,
|
||||
"tool_call_id": self.tool_call_id,
|
||||
"tool_name": self.tool_name,
|
||||
"tool_result": self.tool_result,
|
||||
}
|
||||
|
||||
|
||||
type Content = SystemContent | UserContent | AssistantContent | ToolResultContent
|
||||
|
||||
@@ -211,6 +314,13 @@ class ChatLog:
|
||||
delta_listener: Callable[[ChatLog, dict], None] | None = None
|
||||
llm_input_provided_index = 0
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return a dictionary representation of the chat log."""
|
||||
return {
|
||||
"conversation_id": self.conversation_id,
|
||||
"continue_conversation": self.continue_conversation,
|
||||
}
|
||||
|
||||
@property
|
||||
def continue_conversation(self) -> bool:
|
||||
"""Return whether the conversation should continue."""
|
||||
@@ -241,6 +351,11 @@ class ChatLog:
|
||||
"""Add user content to the log."""
|
||||
LOGGER.debug("Adding user content: %s", content)
|
||||
self.content.append(content)
|
||||
_async_notify_subscribers(
|
||||
self.hass,
|
||||
ChatLogEventType.CONTENT_ADDED,
|
||||
{"conversation_id": self.conversation_id, "content": content.as_dict()},
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_add_assistant_content_without_tools(
|
||||
@@ -259,6 +374,11 @@ class ChatLog:
|
||||
):
|
||||
raise ValueError("Non-external tool calls not allowed")
|
||||
self.content.append(content)
|
||||
_async_notify_subscribers(
|
||||
self.hass,
|
||||
ChatLogEventType.CONTENT_ADDED,
|
||||
{"conversation_id": self.conversation_id, "content": content.as_dict()},
|
||||
)
|
||||
|
||||
async def async_add_assistant_content(
|
||||
self,
|
||||
@@ -317,6 +437,14 @@ class ChatLog:
|
||||
tool_result=tool_result,
|
||||
)
|
||||
self.content.append(response_content)
|
||||
_async_notify_subscribers(
|
||||
self.hass,
|
||||
ChatLogEventType.CONTENT_ADDED,
|
||||
{
|
||||
"conversation_id": self.conversation_id,
|
||||
"content": response_content.as_dict(),
|
||||
},
|
||||
)
|
||||
yield response_content
|
||||
|
||||
async def async_add_delta_content_stream(
|
||||
@@ -593,6 +721,11 @@ class ChatLog:
|
||||
self.llm_api = llm_api
|
||||
self.extra_system_prompt = extra_system_prompt
|
||||
self.content[0] = SystemContent(content=prompt)
|
||||
_async_notify_subscribers(
|
||||
self.hass,
|
||||
ChatLogEventType.UPDATED,
|
||||
{"conversation_id": self.conversation_id, "chat_log": self.as_dict()},
|
||||
)
|
||||
|
||||
LOGGER.debug("Prompt: %s", self.content)
|
||||
LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntFlag
|
||||
from enum import IntFlag, StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
@@ -30,3 +30,12 @@ class ConversationEntityFeature(IntFlag):
|
||||
"""Supported features of the conversation entity."""
|
||||
|
||||
CONTROL = 1
|
||||
|
||||
|
||||
class ChatLogEventType(StrEnum):
|
||||
"""Chat log event type."""
|
||||
|
||||
CREATED = "created"
|
||||
UPDATED = "updated"
|
||||
DELETED = "deleted"
|
||||
CONTENT_ADDED = "content_added"
|
||||
|
||||
@@ -20,6 +20,7 @@ from .agent_manager import (
|
||||
async_get_agent,
|
||||
get_agent_manager,
|
||||
)
|
||||
from .chat_log import async_subscribe_chat_logs
|
||||
from .const import DATA_COMPONENT
|
||||
from .entity import ConversationEntity
|
||||
from .models import ConversationInput
|
||||
@@ -35,6 +36,7 @@ def async_setup(hass: HomeAssistant) -> None:
|
||||
websocket_api.async_register_command(hass, websocket_list_sentences)
|
||||
websocket_api.async_register_command(hass, websocket_hass_agent_debug)
|
||||
websocket_api.async_register_command(hass, websocket_hass_agent_language_scores)
|
||||
websocket_api.async_register_command(hass, websocket_subscribe_chat_logs)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
@@ -265,3 +267,28 @@ class ConversationProcessView(http.HomeAssistantView):
|
||||
)
|
||||
|
||||
return self.json(result.as_dict())
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "conversation/chat_log/subscribe",
|
||||
}
|
||||
)
|
||||
@websocket_api.require_admin
|
||||
def websocket_subscribe_chat_logs(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Subscribe to all chat logs."""
|
||||
|
||||
@callback
|
||||
def forward_events(event_type: str, data: dict) -> None:
|
||||
"""Forward chat log events to websocket connection."""
|
||||
connection.send_message(
|
||||
{"type": "event", "event_type": event_type, "data": data}
|
||||
)
|
||||
|
||||
unsubscribe = async_subscribe_chat_logs(hass, forward_events)
|
||||
connection.subscriptions[msg["id"]] = unsubscribe
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
||||
|
||||
DOMAIN = ha.DOMAIN
|
||||
|
||||
DATA_EXPOSED_ENTITIES: HassKey[ExposedEntities] = HassKey(f"{DOMAIN}.exposed_entites")
|
||||
DATA_EXPOSED_ENTITIES: HassKey[ExposedEntities] = HassKey(f"{DOMAIN}.exposed_entities")
|
||||
DATA_STOP_HANDLER = f"{DOMAIN}.stop_handler"
|
||||
|
||||
SERVICE_HOMEASSISTANT_STOP: Final = "stop"
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from dataclasses import asdict
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -16,7 +18,12 @@ from homeassistant.components.conversation import (
|
||||
UserContent,
|
||||
async_get_chat_log,
|
||||
)
|
||||
from homeassistant.components.conversation.chat_log import DATA_CHAT_LOGS
|
||||
from homeassistant.components.conversation.chat_log import (
|
||||
DATA_CHAT_LOGS,
|
||||
Attachment,
|
||||
ChatLogEventType,
|
||||
async_subscribe_chat_logs,
|
||||
)
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import chat_session, llm
|
||||
@@ -841,3 +848,168 @@ async def test_chat_log_continue_conversation(
|
||||
)
|
||||
)
|
||||
assert chat_log.continue_conversation is True
|
||||
|
||||
|
||||
async def test_chat_log_subscription(
|
||||
hass: HomeAssistant,
|
||||
mock_conversation_input: ConversationInput,
|
||||
) -> None:
|
||||
"""Test comprehensive chat log subscription functionality."""
|
||||
|
||||
# Track all events received
|
||||
received_events = []
|
||||
|
||||
def event_callback(event_type: ChatLogEventType, data: dict[str, Any]) -> None:
|
||||
"""Track received events."""
|
||||
received_events.append((event_type, data))
|
||||
|
||||
# Subscribe to chat log events
|
||||
unsubscribe = async_subscribe_chat_logs(hass, event_callback)
|
||||
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
conversation_id = session.conversation_id
|
||||
|
||||
# Test adding different types of content and verify events are sent
|
||||
chat_log.async_add_user_content(
|
||||
UserContent(
|
||||
content="Check this image",
|
||||
attachments=[
|
||||
Attachment(
|
||||
mime_type="image/jpeg",
|
||||
media_content_id="media-source://bla",
|
||||
path=Path("test_image.jpg"),
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
# Check user content with attachments event
|
||||
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
|
||||
user_event = received_events[-1][1]["content"]
|
||||
assert user_event["content"] == "Check this image"
|
||||
assert len(user_event["attachments"]) == 1
|
||||
assert user_event["attachments"][0]["mime_type"] == "image/jpeg"
|
||||
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(
|
||||
agent_id="test-agent", content="Hello! How can I help you?"
|
||||
)
|
||||
)
|
||||
# Check basic assistant content event
|
||||
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
|
||||
basic_event = received_events[-1][1]["content"]
|
||||
assert basic_event["content"] == "Hello! How can I help you?"
|
||||
assert basic_event["agent_id"] == "test-agent"
|
||||
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(
|
||||
agent_id="test-agent",
|
||||
content="Let me think about that...",
|
||||
thinking_content="I need to analyze the user's request carefully.",
|
||||
)
|
||||
)
|
||||
# Check assistant content with thinking event
|
||||
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
|
||||
thinking_event = received_events[-1][1]["content"]
|
||||
assert (
|
||||
thinking_event["thinking_content"]
|
||||
== "I need to analyze the user's request carefully."
|
||||
)
|
||||
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(
|
||||
agent_id="test-agent",
|
||||
content="Here's some data:",
|
||||
native={"type": "chart", "data": [1, 2, 3, 4, 5]},
|
||||
)
|
||||
)
|
||||
# Check assistant content with native event
|
||||
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
|
||||
native_event = received_events[-1][1]["content"]
|
||||
assert native_event["content"] == "Here's some data:"
|
||||
assert native_event["agent_id"] == "test-agent"
|
||||
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
ToolResultContent(
|
||||
agent_id="test-agent",
|
||||
tool_call_id="test-tool-call-123",
|
||||
tool_name="test_tool",
|
||||
tool_result="Tool execution completed successfully",
|
||||
)
|
||||
)
|
||||
# Check tool result content event
|
||||
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
|
||||
tool_result_event = received_events[-1][1]["content"]
|
||||
assert tool_result_event["tool_name"] == "test_tool"
|
||||
assert (
|
||||
tool_result_event["tool_result"] == "Tool execution completed successfully"
|
||||
)
|
||||
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(
|
||||
agent_id="test-agent",
|
||||
content="I'll call an external service",
|
||||
tool_calls=[
|
||||
llm.ToolInput(
|
||||
id="external-tool-call-123",
|
||||
tool_name="external_api_call",
|
||||
tool_args={"endpoint": "https://api.example.com/data"},
|
||||
external=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
# Check external tool call event
|
||||
assert received_events[-1][0] == ChatLogEventType.CONTENT_ADDED
|
||||
external_tool_event = received_events[-1][1]["content"]
|
||||
assert len(external_tool_event["tool_calls"]) == 1
|
||||
assert external_tool_event["tool_calls"][0].tool_name == "external_api_call"
|
||||
|
||||
# Verify we received the expected events
|
||||
# Should have: 1 CREATED event + 7 CONTENT_ADDED events
|
||||
assert len(received_events) == 8
|
||||
|
||||
# Check the first event is CREATED
|
||||
assert received_events[0][0] == ChatLogEventType.CREATED
|
||||
assert received_events[0][1]["chat_log"]["conversation_id"] == conversation_id
|
||||
|
||||
# Check the second event is CONTENT_ADDED (from mock_conversation_input)
|
||||
assert received_events[1][0] == ChatLogEventType.CONTENT_ADDED
|
||||
assert received_events[1][1]["conversation_id"] == conversation_id
|
||||
|
||||
# Test cleanup functionality
|
||||
assert conversation_id in hass.data[chat_session.DATA_CHAT_SESSION]
|
||||
|
||||
# Set the last updated to be older than the timeout
|
||||
hass.data[chat_session.DATA_CHAT_SESSION][conversation_id].last_updated = (
|
||||
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT
|
||||
)
|
||||
|
||||
async_fire_time_changed(
|
||||
hass,
|
||||
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1),
|
||||
)
|
||||
|
||||
# Check that DELETED event was sent
|
||||
assert received_events[-1][0] == ChatLogEventType.DELETED
|
||||
assert received_events[-1][1]["conversation_id"] == conversation_id
|
||||
|
||||
# Test that unsubscribing stops receiving events
|
||||
events_before_unsubscribe = len(received_events)
|
||||
unsubscribe()
|
||||
|
||||
# Create a new session and add content - should not receive events
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session2,
|
||||
async_get_chat_log(hass, session2, mock_conversation_input) as chat_log2,
|
||||
):
|
||||
chat_log2.async_add_assistant_content_without_tools(
|
||||
AssistantContent(
|
||||
agent_id="test-agent", content="This should not be received"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify no new events were received after unsubscribing
|
||||
assert len(received_events) == events_before_unsubscribe
|
||||
|
||||
@@ -7,17 +7,26 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.conversation import async_get_agent
|
||||
from homeassistant.components.conversation import (
|
||||
ConversationInput,
|
||||
async_get_agent,
|
||||
async_get_chat_log,
|
||||
)
|
||||
from homeassistant.components.conversation.const import HOME_ASSISTANT_AGENT
|
||||
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
||||
from homeassistant.const import ATTR_FRIENDLY_NAME
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import area_registry as ar, entity_registry as er, intent
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
chat_session,
|
||||
entity_registry as er,
|
||||
intent,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import MockAgent
|
||||
|
||||
from tests.common import async_mock_service
|
||||
from tests.common import MockUser, async_mock_service
|
||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||
|
||||
AGENT_ID_OPTIONS = [
|
||||
@@ -590,3 +599,79 @@ async def test_ws_hass_language_scores_with_filter(
|
||||
# GB English should be preferred
|
||||
result = msg["result"]
|
||||
assert result["preferred_language"] == "en-GB"
|
||||
|
||||
|
||||
async def test_ws_chat_log_subscription(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
mock_conversation_input: ConversationInput,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test that we can subscribe to chat logs."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id({"type": "conversation/chat_log/subscribe"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input),
|
||||
):
|
||||
conversation_id = session.conversation_id
|
||||
|
||||
# We should receive 3 events:
|
||||
# 1. The CREATED event (fired before content is added)
|
||||
msg = await client.receive_json()
|
||||
assert msg == {
|
||||
"type": "event",
|
||||
"event_type": "created",
|
||||
"data": {
|
||||
"chat_log": {
|
||||
"conversation_id": conversation_id,
|
||||
"continue_conversation": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# 2. The user input content added event
|
||||
msg = await client.receive_json()
|
||||
assert msg == {
|
||||
"type": "event",
|
||||
"event_type": "content_added",
|
||||
"data": {
|
||||
"conversation_id": conversation_id,
|
||||
"content": {
|
||||
"content": "Hello",
|
||||
"role": "user",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# 3. The DELETED event (since no assistant message was added)
|
||||
msg = await client.receive_json()
|
||||
assert msg == {
|
||||
"type": "event",
|
||||
"event_type": "deleted",
|
||||
"data": {
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def test_ws_chat_log_subscription_requires_admin(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
hass_admin_user: MockUser,
|
||||
) -> None:
|
||||
"""Test that chat log subscription requires admin access."""
|
||||
# Create a non-admin user
|
||||
hass_admin_user.groups = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id({"type": "conversation/chat_log/subscribe"})
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "unauthorized"
|
||||
|
||||
Reference in New Issue
Block a user