From 89e2c57da686f4836f4cb031e4943df662e39571 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 25 May 2024 11:16:51 -0700 Subject: [PATCH] Add conversation agent debug tracing (#118124) * Add debug tracing for conversation agents * Minor cleanup --- .../components/conversation/agent_manager.py | 30 +++-- .../components/conversation/trace.py | 118 ++++++++++++++++++ .../conversation.py | 4 + .../components/ollama/conversation.py | 6 + .../openai_conversation/conversation.py | 4 + homeassistant/helpers/llm.py | 10 +- tests/components/conversation/test_entity.py | 7 ++ tests/components/conversation/test_trace.py | 80 ++++++++++++ .../test_conversation.py | 15 +++ tests/components/ollama/test_conversation.py | 14 +++ .../openai_conversation/test_conversation.py | 15 +++ 11 files changed, 294 insertions(+), 9 deletions(-) create mode 100644 homeassistant/components/conversation/trace.py create mode 100644 tests/components/conversation/test_trace.py diff --git a/homeassistant/components/conversation/agent_manager.py b/homeassistant/components/conversation/agent_manager.py index 9f31ccd6c62..aa8b7644900 100644 --- a/homeassistant/components/conversation/agent_manager.py +++ b/homeassistant/components/conversation/agent_manager.py @@ -2,6 +2,7 @@ from __future__ import annotations +import dataclasses import logging from typing import Any @@ -20,6 +21,11 @@ from .models import ( ConversationInput, ConversationResult, ) +from .trace import ( + ConversationTraceEvent, + ConversationTraceEventType, + async_conversation_trace, +) _LOGGER = logging.getLogger(__name__) @@ -84,15 +90,23 @@ async def async_converse( language = hass.config.language _LOGGER.debug("Processing in %s: %s", language, text) - return await method( - ConversationInput( - text=text, - context=context, - conversation_id=conversation_id, - device_id=device_id, - language=language, - ) + conversation_input = ConversationInput( + text=text, + context=context, + conversation_id=conversation_id, + 
device_id=device_id, + language=language, ) + with async_conversation_trace() as trace: + trace.add_event( + ConversationTraceEvent( + ConversationTraceEventType.ASYNC_PROCESS, + dataclasses.asdict(conversation_input), + ) + ) + result = await method(conversation_input) + trace.set_result(**result.as_dict()) + return result class AgentManager: diff --git a/homeassistant/components/conversation/trace.py b/homeassistant/components/conversation/trace.py new file mode 100644 index 00000000000..0bd2fe8ed5b --- /dev/null +++ b/homeassistant/components/conversation/trace.py @@ -0,0 +1,118 @@ +"""Debug traces for conversation.""" + +from collections.abc import Generator +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import asdict, dataclass, field +import enum +from typing import Any + +from homeassistant.util import dt as dt_util, ulid as ulid_util +from homeassistant.util.limited_size_dict import LimitedSizeDict + +STORED_TRACES = 3 + + +class ConversationTraceEventType(enum.StrEnum): + """Type of an event emitted during a conversation.""" + + ASYNC_PROCESS = "async_process" + """The conversation is started from user input.""" + + AGENT_DETAIL = "agent_detail" + """Event detail added by a conversation agent.""" + + LLM_TOOL_CALL = "llm_tool_call" + """An LLM Tool call""" + + +@dataclass(frozen=True) +class ConversationTraceEvent: + """Event emitted during a conversation.""" + + event_type: ConversationTraceEventType + data: dict[str, Any] | None = None + timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat()) + + +class ConversationTrace: + """Stores debug data related to a conversation.""" + + def __init__(self) -> None: + """Initialize ConversationTrace.""" + self._trace_id = ulid_util.ulid_now() + self._events: list[ConversationTraceEvent] = [] + self._error: Exception | None = None + self._result: dict[str, Any] = {} + + @property + def trace_id(self) -> str: + """Identifier for this trace.""" + 
return self._trace_id + + def add_event(self, event: ConversationTraceEvent) -> None: + """Add an event to the trace.""" + self._events.append(event) + + def set_error(self, ex: Exception) -> None: + """Set error.""" + self._error = ex + + def set_result(self, **kwargs: Any) -> None: + """Set result.""" + self._result = {**kwargs} + + def as_dict(self) -> dict[str, Any]: + """Return dictionary version of this ConversationTrace.""" + result: dict[str, Any] = { + "id": self._trace_id, + "events": [asdict(event) for event in self._events], + } + if self._error is not None: + result["error"] = str(self._error) or self._error.__class__.__name__ + if self._result is not None: + result["result"] = self._result + return result + + +_current_trace: ContextVar[ConversationTrace | None] = ContextVar( + "current_trace", default=None +) +_recent_traces: LimitedSizeDict[str, ConversationTrace] = LimitedSizeDict( + size_limit=STORED_TRACES +) + + +def async_conversation_trace_append( + event_type: ConversationTraceEventType, event_data: dict[str, Any] +) -> None: + """Append a ConversationTraceEvent to the current active trace.""" + trace = _current_trace.get() + if not trace: + return + trace.add_event(ConversationTraceEvent(event_type, event_data)) + + +@contextmanager +def async_conversation_trace() -> Generator[ConversationTrace, None]: + """Create a new active ConversationTrace.""" + trace = ConversationTrace() + token = _current_trace.set(trace) + _recent_traces[trace.trace_id] = trace + try: + yield trace + except Exception as ex: + trace.set_error(ex) + raise + finally: + _current_trace.reset(token) + + +def async_get_traces() -> list[ConversationTrace]: + """Get the most recent traces.""" + return list(_recent_traces.values()) + + +def async_clear_traces() -> None: + """Clear all traces.""" + _recent_traces.clear() diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py 
b/homeassistant/components/google_generative_ai_conversation/conversation.py index 8a6a761d549..f84bd81f80c 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -12,6 +12,7 @@ import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation +from homeassistant.components.conversation import trace from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant @@ -250,6 +251,9 @@ class GoogleGenerativeAIConversationEntity( messages[1] = {"role": "model", "parts": "Ok"} LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages) + trace.async_conversation_trace_append( + trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages} + ) chat = model.start_chat(history=messages) chat_request = user_input.text diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py index cbec719780a..fa7a3c3797e 100644 --- a/homeassistant/components/ollama/conversation.py +++ b/homeassistant/components/ollama/conversation.py @@ -9,6 +9,7 @@ from typing import Literal import ollama from homeassistant.components import assist_pipeline, conversation +from homeassistant.components.conversation import trace from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.config_entries import ConfigEntry from homeassistant.const import MATCH_ALL @@ -138,6 +139,11 @@ class OllamaConversationEntity( ollama.Message(role=MessageRole.USER.value, content=user_input.text) ) + trace.async_conversation_trace_append( + trace.ConversationTraceEventType.AGENT_DETAIL, + {"messages": message_history.messages}, + ) + # Get response try: response = await client.chat( diff --git 
a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index 2bd21429d9f..be3b8ea9126 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -8,6 +8,7 @@ import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation +from homeassistant.components.conversation import trace from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant @@ -169,6 +170,9 @@ class OpenAIConversationEntity( messages.append({"role": "user", "content": user_input.text}) LOGGER.debug("Prompt: %s", messages) + trace.async_conversation_trace_append( + trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages} + ) client = self.hass.data[DOMAIN][self.entry.entry_id] diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index cde644a7641..1ffc2880547 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -3,12 +3,16 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Any import voluptuous as vol from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE +from homeassistant.components.conversation.trace import ( + ConversationTraceEventType, + async_conversation_trace_append, +) from homeassistant.components.weather.intent import INTENT_GET_WEATHER from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError @@ -116,6 +120,10 @@ class API(ABC): async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: """Call a LLM tool, validate args and return the response.""" + async_conversation_trace_append( + 
ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input) + ) + for tool in self.async_get_tools(): if tool.name == tool_input.tool_name: break diff --git a/tests/components/conversation/test_entity.py b/tests/components/conversation/test_entity.py index c84f94c4aa4..109c0ed361f 100644 --- a/tests/components/conversation/test_entity.py +++ b/tests/components/conversation/test_entity.py @@ -2,7 +2,9 @@ from unittest.mock import patch +from homeassistant.components import conversation from homeassistant.core import Context, HomeAssistant, State +from homeassistant.helpers import intent from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -31,6 +33,11 @@ async def test_state_set_and_restore(hass: HomeAssistant) -> None: ) as mock_process, patch("homeassistant.util.dt.utcnow", return_value=now), ): + intent_response = intent.IntentResponse(language="en") + intent_response.async_set_speech("response text") + mock_process.return_value = conversation.ConversationResult( + response=intent_response, + ) await hass.services.async_call( "conversation", "process", diff --git a/tests/components/conversation/test_trace.py b/tests/components/conversation/test_trace.py new file mode 100644 index 00000000000..c586eb8865d --- /dev/null +++ b/tests/components/conversation/test_trace.py @@ -0,0 +1,80 @@ +"""Test for the conversation traces.""" + +from unittest.mock import patch + +import pytest + +from homeassistant.components import conversation +from homeassistant.components.conversation import trace +from homeassistant.core import Context, HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.setup import async_setup_component + + +@pytest.fixture +async def init_components(hass: HomeAssistant): + """Initialize relevant components with empty configs.""" + assert await async_setup_component(hass, "homeassistant", {}) + assert await async_setup_component(hass, "conversation", {}) + assert await 
async_setup_component(hass, "intent", {}) + + +async def test_converation_trace( + hass: HomeAssistant, + init_components: None, + sl_setup: None, +) -> None: + """Test tracing a conversation.""" + await conversation.async_converse( + hass, "add apples to my shopping list", None, Context() + ) + + traces = trace.async_get_traces() + assert traces + last_trace = traces[-1].as_dict() + assert last_trace.get("events") + assert len(last_trace.get("events")) == 1 + trace_event = last_trace["events"][0] + assert ( + trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS + ) + assert trace_event.get("data") + assert trace_event["data"].get("text") == "add apples to my shopping list" + assert last_trace.get("result") + assert ( + last_trace["result"] + .get("response", {}) + .get("speech", {}) + .get("plain", {}) + .get("speech") + == "Added apples" + ) + + +async def test_converation_trace_error( + hass: HomeAssistant, + init_components: None, + sl_setup: None, +) -> None: + """Test tracing a conversation.""" + with ( + patch( + "homeassistant.components.conversation.default_agent.DefaultAgent.async_process", + side_effect=HomeAssistantError("Failed to talk to agent"), + ), + pytest.raises(HomeAssistantError), + ): + await conversation.async_converse( + hass, "add apples to my shopping list", None, Context() + ) + + traces = trace.async_get_traces() + assert traces + last_trace = traces[-1].as_dict() + assert last_trace.get("events") + assert len(last_trace.get("events")) == 1 + trace_event = last_trace["events"][0] + assert ( + trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS + ) + assert last_trace.get("error") == "Failed to talk to agent" diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index b31d9442a43..4c208c240b8 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ 
b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -9,6 +9,7 @@ from syrupy.assertion import SnapshotAssertion import voluptuous as vol from homeassistant.components import conversation +from homeassistant.components.conversation import trace from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError @@ -285,6 +286,20 @@ async def test_function_call( ), ) + # Test Conversation tracing + traces = trace.async_get_traces() + assert traces + last_trace = traces[-1].as_dict() + trace_events = last_trace.get("events", []) + assert [event["event_type"] for event in trace_events] == [ + trace.ConversationTraceEventType.ASYNC_PROCESS, + trace.ConversationTraceEventType.AGENT_DETAIL, + trace.ConversationTraceEventType.LLM_TOOL_CALL, + ] + # AGENT_DETAIL event contains the raw prompt passed to the model + detail_event = trace_events[1] + assert "Answer in plain text" in detail_event["data"]["messages"][0]["parts"] + @patch( "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools" diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py index 080d0d34f2d..b6f0be3c414 100644 --- a/tests/components/ollama/test_conversation.py +++ b/tests/components/ollama/test_conversation.py @@ -6,6 +6,7 @@ from ollama import Message, ResponseError import pytest from homeassistant.components import conversation, ollama +from homeassistant.components.conversation import trace from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL from homeassistant.core import Context, HomeAssistant @@ -110,6 +111,19 @@ async def test_chat( ), result assert result.response.speech["plain"]["speech"] == "test response" + # Test Conversation tracing + traces = trace.async_get_traces() + assert traces + 
last_trace = traces[-1].as_dict() + trace_events = last_trace.get("events", []) + assert [event["event_type"] for event in trace_events] == [ + trace.ConversationTraceEventType.ASYNC_PROCESS, + trace.ConversationTraceEventType.AGENT_DETAIL, + ] + # AGENT_DETAIL event contains the raw prompt passed to the model + detail_event = trace_events[1] + assert "The current time is" in detail_event["data"]["messages"][0]["content"] + async def test_message_history_trimming( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 319295374a7..3fa5c307b6d 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -15,6 +15,7 @@ from syrupy.assertion import SnapshotAssertion import voluptuous as vol from homeassistant.components import conversation +from homeassistant.components.conversation import trace from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError @@ -200,6 +201,20 @@ async def test_function_call( ), ) + # Test Conversation tracing + traces = trace.async_get_traces() + assert traces + last_trace = traces[-1].as_dict() + trace_events = last_trace.get("events", []) + assert [event["event_type"] for event in trace_events] == [ + trace.ConversationTraceEventType.ASYNC_PROCESS, + trace.ConversationTraceEventType.AGENT_DETAIL, + trace.ConversationTraceEventType.LLM_TOOL_CALL, + ] + # AGENT_DETAIL event contains the raw prompt passed to the model + detail_event = trace_events[1] + assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"] + @patch( "homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"