From 754de6f998e213686a04a056fa4c3357467fe37e Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 18 Jan 2025 22:33:03 -0500 Subject: [PATCH] Add shared history for conversation agents (#135903) * Add shared history for conversation agents * Remove unused code * Add support for native history items * Store all assistant responses as assistant in history * Add history support to DefaultAgent.async_handle_intents * Make local fallback work * Add default agent history * Add history cleanup * Add tests * ChatHistory -> ChatSession * Address comments * Update snapshots --- .../components/assist_pipeline/pipeline.py | 49 ++- .../components/conversation/__init__.py | 5 + .../components/conversation/default_agent.py | 73 ++-- .../components/conversation/session.py | 327 ++++++++++++++++++ .../openai_conversation/conversation.py | 254 +++++--------- homeassistant/helpers/llm.py | 10 +- .../assist_pipeline/snapshots/test_init.ambr | 14 +- .../snapshots/test_websocket.ambr | 14 +- .../snapshots/test_default_agent.ambr | 38 +- .../conversation/snapshots/test_http.ambr | 24 +- .../conversation/snapshots/test_init.ambr | 26 +- tests/components/conversation/test_session.py | 171 +++++++++ .../snapshots/test_conversation.ambr | 2 +- .../openai_conversation/test_conversation.py | 6 +- tests/syrupy.py | 7 + 15 files changed, 744 insertions(+), 276 deletions(-) create mode 100644 homeassistant/components/conversation/session.py create mode 100644 tests/components/conversation/test_session.py diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index a11b5a657de..9353bbe0007 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -1065,7 +1065,8 @@ class PipelineRun: ) processed_locally = self.intent_agent == conversation.HOME_ASSISTANT_AGENT - conversation_result: conversation.ConversationResult | None = None + agent_id = user_input.agent_id + intent_response: intent.IntentResponse | None = None if user_input.agent_id != conversation.HOME_ASSISTANT_AGENT: # Sentence triggers override conversation agent if ( @@ -1075,14 +1076,12 @@ class PipelineRun: ) ) is not None: # Sentence trigger matched - trigger_response = intent.IntentResponse( + agent_id = "sentence_trigger" + intent_response = intent.IntentResponse( self.pipeline.conversation_language ) - trigger_response.async_set_speech(trigger_response_text) - conversation_result = conversation.ConversationResult( - response=trigger_response, - conversation_id=user_input.conversation_id, - ) + intent_response.async_set_speech(trigger_response_text) + # Try local intents first, if preferred. elif self.pipeline.prefer_local_intents and ( intent_response := await conversation.async_handle_intents( @@ -1090,13 +1089,31 @@ class PipelineRun: ) ): # Local intent matched - conversation_result = conversation.ConversationResult( - response=intent_response, - conversation_id=user_input.conversation_id, - ) + agent_id = conversation.HOME_ASSISTANT_AGENT processed_locally = True - if conversation_result is None: + # It was already handled, create response and add to chat history + if intent_response is not None: + async with conversation.async_get_chat_session( + self.hass, user_input + ) as chat_session: + speech: str = intent_response.speech.get("plain", {}).get( + "speech", "" + ) + chat_session.async_add_message( + conversation.ChatMessage( + role="assistant", + agent_id=agent_id, + content=speech, + native=intent_response, + ) + ) + conversation_result = conversation.ConversationResult( + response=intent_response, + conversation_id=chat_session.conversation_id, + ) + + else: # Fall back to pipeline conversation agent conversation_result = await conversation.async_converse( hass=self.hass, @@ -1107,6 +1124,10 @@ class PipelineRun: language=user_input.language, agent_id=user_input.agent_id, ) + speech = conversation_result.response.speech.get("plain", {}).get( + "speech", "" + ) + except Exception as src_error: _LOGGER.exception("Unexpected error during intent recognition") raise IntentRecognitionError( @@ -1126,10 +1147,6 @@ class PipelineRun: ) ) - speech: str = conversation_result.response.speech.get("plain", {}).get( - "speech", "" - ) - return speech async def prepare_text_to_speech(self) -> None: diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index 898b7b2cf4f..9c1db128f15 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -48,20 +48,25 @@ from .default_agent import DefaultAgent, async_setup_default_agent from .entity import ConversationEntity from .http import async_setup as async_setup_conversation_http from .models import AbstractConversationAgent, ConversationInput, ConversationResult +from .session import ChatMessage, ChatSession, ConverseError, async_get_chat_session from .trace import ConversationTraceEventType, async_conversation_trace_append __all__ = [ "DOMAIN", "HOME_ASSISTANT_AGENT", "OLD_HOME_ASSISTANT_AGENT", + "ChatMessage", + "ChatSession", "ConversationEntity", "ConversationEntityFeature", "ConversationInput", "ConversationResult", "ConversationTraceEventType", + "ConverseError", "async_conversation_trace_append", "async_converse", "async_get_agent_info", + "async_get_chat_session", "async_set_agent", "async_setup", "async_unset_agent", diff --git a/homeassistant/components/conversation/default_agent.py b/homeassistant/components/conversation/default_agent.py index 66ffb25fa1a..d4773d50c4b 100644 --- a/homeassistant/components/conversation/default_agent.py +++ b/homeassistant/components/conversation/default_agent.py @@ -62,6 +62,7 @@ from .const import ( ) from .entity import ConversationEntity from .models import ConversationInput, ConversationResult +from .session import ChatMessage, async_get_chat_session from .trace import ConversationTraceEventType, async_conversation_trace_append _LOGGER = logging.getLogger(__name__) @@ -346,35 +347,52 @@ class DefaultAgent(ConversationEntity): async def async_process(self, user_input: ConversationInput) -> ConversationResult: """Process a sentence.""" + response: intent.IntentResponse | None = None + async with async_get_chat_session(self.hass, user_input) as chat_session: + # Check if a trigger matched + if trigger_result := await self.async_recognize_sentence_trigger( + user_input + ): + # Process callbacks and get response + response_text = await self._handle_trigger_result( + trigger_result, user_input + ) - # Check if a trigger matched - if trigger_result := await self.async_recognize_sentence_trigger(user_input): - # Process callbacks and get response - response_text = await self._handle_trigger_result( - trigger_result, user_input + # Convert to conversation result + response = intent.IntentResponse( + language=user_input.language or self.hass.config.language + ) + response.response_type = intent.IntentResponseType.ACTION_DONE + response.async_set_speech(response_text) + + if response is None: + # Match intents + intent_result = await self.async_recognize_intent(user_input) + response = await self._async_process_intent_result( + intent_result, user_input + ) + + speech: str = response.speech.get("plain", {}).get("speech", "") + chat_session.async_add_message( + ChatMessage( + role="assistant", + agent_id=user_input.agent_id, + content=speech, + native=response, + ) ) - # Convert to conversation result - response = intent.IntentResponse( - language=user_input.language or self.hass.config.language + return ConversationResult( + response=response, conversation_id=chat_session.conversation_id ) - response.response_type = intent.IntentResponseType.ACTION_DONE - response.async_set_speech(response_text) - - return ConversationResult(response=response) - - # Match intents - intent_result = await self.async_recognize_intent(user_input) - return await self._async_process_intent_result(intent_result, user_input) async def _async_process_intent_result( self, result: RecognizeResult | None, user_input: ConversationInput, - ) -> ConversationResult: + ) -> intent.IntentResponse: """Process user input with intents.""" language = user_input.language or self.hass.config.language - conversation_id = None # Not supported # Intent match or failure lang_intents = await self.async_get_or_load_intents(language) @@ -386,7 +404,6 @@ class DefaultAgent(ConversationEntity): language, intent.IntentResponseErrorCode.NO_INTENT_MATCH, self._get_error_text(ErrorKey.NO_INTENT, lang_intents), - conversation_id, ) if result.unmatched_entities: @@ -408,7 +425,6 @@ class DefaultAgent(ConversationEntity): self._get_error_text( error_response_type, lang_intents, **error_response_args ), - conversation_id, ) # Will never happen because result will be None when no intents are @@ -461,7 +477,6 @@ class DefaultAgent(ConversationEntity): self._get_error_text( error_response_type, lang_intents, **error_response_args ), - conversation_id, ) except intent.IntentHandleError as err: # Intent was valid and entities matched constraints, but an error @@ -473,7 +488,6 @@ class DefaultAgent(ConversationEntity): self._get_error_text( err.response_key or ErrorKey.HANDLE_ERROR, lang_intents ), - conversation_id, ) except intent.IntentUnexpectedError: _LOGGER.exception("Unexpected intent error") @@ -481,7 +495,6 @@ class DefaultAgent(ConversationEntity): language, intent.IntentResponseErrorCode.UNKNOWN, self._get_error_text(ErrorKey.HANDLE_ERROR, lang_intents), - conversation_id, ) if ( @@ -500,9 +513,7 @@ class DefaultAgent(ConversationEntity): ) intent_response.async_set_speech(speech) - return ConversationResult( - response=intent_response, conversation_id=conversation_id - ) + return intent_response def _recognize( self, @@ -1346,22 +1357,18 @@ class DefaultAgent(ConversationEntity): # No error message on failed match return None - conversation_result = await self._async_process_intent_result( - result, user_input - ) - return conversation_result.response + return await self._async_process_intent_result(result, user_input) def _make_error_result( language: str, error_code: intent.IntentResponseErrorCode, response_text: str, - conversation_id: str | None = None, -) -> ConversationResult: +) -> intent.IntentResponse: """Create conversation result with error code and text.""" response = intent.IntentResponse(language=language) response.async_set_error(error_code, response_text) - return ConversationResult(response, conversation_id) + return response def _get_unmatched_response(result: RecognizeResult) -> tuple[ErrorKey, dict[str, Any]]: diff --git a/homeassistant/components/conversation/session.py b/homeassistant/components/conversation/session.py new file mode 100644 index 00000000000..f9db80afa63 --- /dev/null +++ b/homeassistant/components/conversation/session.py @@ -0,0 +1,327 @@ +"""Conversation history.""" + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field, replace +from datetime import datetime, timedelta +import logging +from typing import Generic, Literal, TypeVar + +from homeassistant.const import EVENT_HOMEASSISTANT_STOP +from homeassistant.core import ( + CALLBACK_TYPE, + Event, + HassJob, + HassJobType, + HomeAssistant, + callback, +) +from homeassistant.exceptions import HomeAssistantError, TemplateError +from homeassistant.helpers import intent, llm, template +from homeassistant.helpers.event import async_call_later +from homeassistant.util import dt as dt_util, ulid +from homeassistant.util.hass_dict import HassKey + +from .const import DOMAIN +from .models import ConversationInput, ConversationResult + +DATA_CHAT_HISTORY: HassKey["dict[str, ChatSession]"] = HassKey( + "conversation_chat_session" +) +DATA_CHAT_HISTORY_CLEANUP: HassKey["SessionCleanup"] = HassKey( + "conversation_chat_session_cleanup" +) + +LOGGER = logging.getLogger(__name__) +CONVERSATION_TIMEOUT = timedelta(minutes=5) +_NativeT = TypeVar("_NativeT") + + +class SessionCleanup: + """Helper to clean up the history.""" + + unsub: CALLBACK_TYPE | None = None + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the history cleanup.""" + self.hass = hass + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop) + self.cleanup_job = HassJob( + self._cleanup, "conversation_history_cleanup", job_type=HassJobType.Callback + ) + + @callback + def schedule(self) -> None: + """Schedule the cleanup.""" + if self.unsub: + return + self.unsub = async_call_later( + self.hass, + CONVERSATION_TIMEOUT.total_seconds() + 1, + self.cleanup_job, + ) + + @callback + def _on_hass_stop(self, event: Event) -> None: + """Cancel the cleanup on shutdown.""" + if self.unsub: + self.unsub() + self.unsub = None + + @callback + def _cleanup(self, now: datetime) -> None: + """Clean up the history and schedule follow-up if necessary.""" + self.unsub = None + all_history = self.hass.data[DATA_CHAT_HISTORY] + + # We mutate original object because current commands could be + # yielding history based on it. + for conversation_id, history in list(all_history.items()): + if history.last_updated + CONVERSATION_TIMEOUT < now: + del all_history[conversation_id] + + # Still conversations left, check again in timeout time. + if all_history: + self.schedule() + + +@asynccontextmanager +async def async_get_chat_session( + hass: HomeAssistant, + user_input: ConversationInput, +) -> AsyncGenerator["ChatSession"]: + """Return chat session.""" + all_history = hass.data.get(DATA_CHAT_HISTORY) + if all_history is None: + all_history = {} + hass.data[DATA_CHAT_HISTORY] = all_history + hass.data[DATA_CHAT_HISTORY_CLEANUP] = SessionCleanup(hass) + + history: ChatSession | None = None + + if user_input.conversation_id is None: + conversation_id = ulid.ulid_now() + + elif history := all_history.get(user_input.conversation_id): + conversation_id = user_input.conversation_id + + else: + # Conversation IDs are ULIDs. We generate a new one if not provided. + # If an old OLID is passed in, we will generate a new one to indicate + # a new conversation was started. If the user picks their own, they + # want to track a conversation and we respect it. + try: + ulid.ulid_to_bytes(user_input.conversation_id) + conversation_id = ulid.ulid_now() + except ValueError: + conversation_id = user_input.conversation_id + + if history: + history = replace(history, messages=history.messages.copy()) + else: + history = ChatSession(hass, conversation_id) + + message: ChatMessage = ChatMessage( + role="user", + agent_id=user_input.agent_id, + content=user_input.text, + ) + history.async_add_message(message) + + yield history + + if history.messages[-1] is message: + LOGGER.debug( + "History opened but no assistant message was added, ignoring update" + ) + return + + history.last_updated = dt_util.utcnow() + all_history[conversation_id] = history + hass.data[DATA_CHAT_HISTORY_CLEANUP].schedule() + + +class ConverseError(HomeAssistantError): + """Error during initialization of conversation. + + Will not be stored in the history. + """ + + def __init__( + self, message: str, conversation_id: str, response: intent.IntentResponse + ) -> None: + """Initialize the error.""" + super().__init__(message) + self.conversation_id = conversation_id + self.response = response + + def as_converstation_result(self) -> ConversationResult: + """Return the error as a conversation result.""" + return ConversationResult( + response=self.response, + conversation_id=self.conversation_id, + ) + + +@dataclass +class ChatMessage(Generic[_NativeT]): + """Base class for chat messages. + + When role is native, the content is to be ignored and message + is only meant for storing the native object. + """ + + role: Literal["system", "assistant", "user", "native"] + agent_id: str | None + content: str + native: _NativeT | None = field(default=None) + + # Validate in post-init that if role is native, there is no content and a native object exists + def __post_init__(self) -> None: + """Validate native message.""" + if self.role == "native" and self.native is None: + raise ValueError("Native message must have a native object") + + +@dataclass +class ChatSession(Generic[_NativeT]): + """Class holding all information for a specific conversation.""" + + hass: HomeAssistant + conversation_id: str + user_name: str | None = None + messages: list[ChatMessage[_NativeT]] = field( + default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")] + ) + extra_system_prompt: str | None = None + llm_api: llm.APIInstance | None = None + last_updated: datetime = field(default_factory=dt_util.utcnow) + + @callback + def async_add_message(self, message: ChatMessage[_NativeT]) -> None: + """Process intent.""" + if message.role == "system": + raise ValueError("Cannot add system messages to history") + if message.role != "native" and self.messages[-1].role == message.role: + raise ValueError("Cannot add two assistant or user messages in a row") + + self.messages.append(message) + + @callback + def async_get_messages(self, agent_id: str | None) -> list[ChatMessage[_NativeT]]: + """Get messages for a specific agent ID. + + This will filter out any native message tied to other agent IDs. + It can still include assistant/user messages generated by other agents. + """ + return [ + message + for message in self.messages + if message.role != "native" or message.agent_id == agent_id + ] + + async def async_process_llm_message( + self, + conversing_domain: str, + user_input: ConversationInput, + user_llm_hass_api: str | None = None, + user_llm_prompt: str | None = None, + ) -> None: + """Process an incoming message for an LLM.""" + llm_context = llm.LLMContext( + platform=conversing_domain, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=DOMAIN, + device_id=user_input.device_id, + ) + + llm_api: llm.APIInstance | None = None + + if user_llm_hass_api: + try: + llm_api = await llm.async_get_api( + self.hass, + user_llm_hass_api, + llm_context, + ) + except HomeAssistantError as err: + LOGGER.error( + "Error getting LLM API %s for %s: %s", + user_llm_hass_api, + conversing_domain, + err, + ) + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + "Error preparing LLM API", + ) + raise ConverseError( + f"Error getting LLM API {user_llm_hass_api}", + conversation_id=self.conversation_id, + response=intent_response, + ) from err + + user_name: str | None = None + + if ( + user_input.context + and user_input.context.user_id + and ( + user := await self.hass.auth.async_get_user(user_input.context.user_id) + ) + ): + user_name = user.name + + try: + prompt_parts = [ + template.Template( + llm.BASE_PROMPT + + (user_llm_prompt or llm.DEFAULT_INSTRUCTIONS_PROMPT), + self.hass, + ).async_render( + { + "ha_name": self.hass.config.location_name, + "user_name": user_name, + "llm_context": llm_context, + }, + parse_result=False, + ) + ] + + except TemplateError as err: + LOGGER.error("Error rendering prompt: %s", err) + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + "Sorry, I had a problem with my template", + ) + raise ConverseError( + "Error rendering prompt", + conversation_id=self.conversation_id, + response=intent_response, + ) from err + + if llm_api: + prompt_parts.append(llm_api.api_prompt) + + extra_system_prompt = ( + # Take new system prompt if one was given + user_input.extra_system_prompt or self.extra_system_prompt + ) + + if extra_system_prompt: + prompt_parts.append(extra_system_prompt) + + prompt = "\n".join(prompt_parts) + + self.llm_api = llm_api + self.user_name = user_name + self.extra_system_prompt = extra_system_prompt + self.messages[0] = ChatMessage( + role="system", + agent_id=user_input.agent_id, + content=prompt, + ) diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index b3f31ae9b47..9a6b61e4c43 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -1,9 +1,8 @@ """Conversation support for OpenAI.""" from collections.abc import Callable -from dataclasses import dataclass, field import json -from typing import Any, Literal +from typing import Any, Literal, cast import openai from openai._types import NOT_GIVEN @@ -12,10 +11,8 @@ from openai.types.chat import ( ChatCompletionMessage, ChatCompletionMessageParam, ChatCompletionMessageToolCallParam, - ChatCompletionSystemMessageParam, ChatCompletionToolMessageParam, ChatCompletionToolParam, - ChatCompletionUserMessageParam, ) from openai.types.chat.chat_completion_message_tool_call_param import Function from openai.types.shared_params import FunctionDefinition @@ -27,10 +24,9 @@ from homeassistant.components.conversation import trace from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant -from homeassistant.exceptions import HomeAssistantError, TemplateError -from homeassistant.helpers import device_registry as dr, intent, llm, template +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import device_registry as dr, intent, llm from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.util import ulid from . import OpenAIConfigEntry from .const import ( @@ -74,12 +70,28 @@ def _format_tool( return ChatCompletionToolParam(type="function", function=tool_spec) -@dataclass -class ChatHistory: - """Class holding the chat history.""" - - extra_system_prompt: str | None = None - messages: list[ChatCompletionMessageParam] = field(default_factory=list) +def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessageParam: + """Convert from class to TypedDict.""" + tool_calls: list[ChatCompletionMessageToolCallParam] = [] + if message.tool_calls: + tool_calls = [ + ChatCompletionMessageToolCallParam( + id=tool_call.id, + function=Function( + arguments=tool_call.function.arguments, + name=tool_call.function.name, + ), + type=tool_call.type, + ) + for tool_call in message.tool_calls + ] + param = ChatCompletionAssistantMessageParam( + role=message.role, + content=message.content, + ) + if tool_calls: + param["tool_calls"] = tool_calls + return param class OpenAIConversationEntity( @@ -93,7 +105,6 @@ class OpenAIConversationEntity( def __init__(self, entry: OpenAIConfigEntry) -> None: """Initialize the agent.""" self.entry = entry - self.history: dict[str, ChatHistory] = {} self._attr_unique_id = entry.entry_id self._attr_device_info = dr.DeviceInfo( identifiers={(DOMAIN, entry.entry_id)}, @@ -132,127 +143,56 @@ class OpenAIConversationEntity( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" + async with conversation.async_get_chat_session( + self.hass, user_input + ) as session: + return await self._async_call_api(user_input, session) + + async def _async_call_api( + self, + user_input: conversation.ConversationInput, + session: conversation.ChatSession[ChatCompletionMessageParam], + ) -> conversation.ConversationResult: + """Call the API.""" options = self.entry.options - intent_response = intent.IntentResponse(language=user_input.language) - llm_api: llm.APIInstance | None = None - tools: list[ChatCompletionToolParam] | None = None - user_name: str | None = None - llm_context = llm.LLMContext( - platform=DOMAIN, - context=user_input.context, - user_prompt=user_input.text, - language=user_input.language, - assistant=conversation.DOMAIN, - device_id=user_input.device_id, - ) - - if options.get(CONF_LLM_HASS_API): - try: - llm_api = await llm.async_get_api( - self.hass, - options[CONF_LLM_HASS_API], - llm_context, - ) - except HomeAssistantError as err: - LOGGER.error("Error getting LLM API: %s", err) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - "Error preparing LLM API", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=user_input.conversation_id - ) - tools = [ - _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools - ] - - history: ChatHistory | None = None - - if user_input.conversation_id is None: - conversation_id = ulid.ulid_now() - - elif user_input.conversation_id in self.history: - conversation_id = user_input.conversation_id - history = self.history.get(conversation_id) - - else: - # Conversation IDs are ULIDs. We generate a new one if not provided. - # If an old OLID is passed in, we will generate a new one to indicate - # a new conversation was started. If the user picks their own, they - # want to track a conversation and we respect it. - try: - ulid.ulid_to_bytes(user_input.conversation_id) - conversation_id = ulid.ulid_now() - except ValueError: - conversation_id = user_input.conversation_id - - if history is None: - history = ChatHistory(user_input.extra_system_prompt) - - if ( - user_input.context - and user_input.context.user_id - and ( - user := await self.hass.auth.async_get_user(user_input.context.user_id) - ) - ): - user_name = user.name try: - prompt_parts = [ - template.Template( - llm.BASE_PROMPT - + options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT), - self.hass, - ).async_render( - { - "ha_name": self.hass.config.location_name, - "user_name": user_name, - "llm_context": llm_context, - }, - parse_result=False, - ) + await session.async_process_llm_message( + DOMAIN, + user_input, + options.get(CONF_LLM_HASS_API), + options.get(CONF_PROMPT), + ) + except conversation.ConverseError as err: + return err.as_converstation_result() + + tools: list[ChatCompletionToolParam] | None = None + if session.llm_api: + tools = [ + _format_tool(tool, session.llm_api.custom_serializer) + for tool in session.llm_api.tools ] - except TemplateError as err: - LOGGER.error("Error rendering prompt: %s", err) - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - "Sorry, I had a problem with my template", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) + messages: list[ChatCompletionMessageParam] = [] + for message in session.async_get_messages(user_input.agent_id): + if message.native is not None and message.agent_id == user_input.agent_id: + messages.append(message.native) + else: + messages.append( + cast( + ChatCompletionMessageParam, + {"role": message.role, "content": message.content}, + ) + ) - if llm_api: - prompt_parts.append(llm_api.api_prompt) - - extra_system_prompt = ( - # Take new system prompt if one was given - user_input.extra_system_prompt or history.extra_system_prompt - ) - - if extra_system_prompt: - prompt_parts.append(extra_system_prompt) - - prompt = "\n".join(prompt_parts) - - # Create a copy of the variable because we attach it to the trace - history = ChatHistory( - extra_system_prompt, - [ - ChatCompletionSystemMessageParam(role="system", content=prompt), - *history.messages[1:], - ChatCompletionUserMessageParam(role="user", content=user_input.text), - ], - ) - - LOGGER.debug("Prompt: %s", history.messages) + LOGGER.debug("Prompt: %s", messages) LOGGER.debug("Tools: %s", tools) trace.async_conversation_trace_append( trace.ConversationTraceEventType.AGENT_DETAIL, - {"messages": history.messages, "tools": llm_api.tools if llm_api else None}, + { + "messages": session.messages, + "tools": session.llm_api.tools if session.llm_api else None, + }, ) client = self.entry.runtime_data @@ -262,12 +202,12 @@ class OpenAIConversationEntity( try: result = await client.chat.completions.create( model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), - messages=history.messages, + messages=messages, tools=tools or NOT_GIVEN, max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P), temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), - user=conversation_id, + user=session.conversation_id, ) except openai.OpenAIError as err: LOGGER.error("Error talking to OpenAI: %s", err) @@ -277,44 +217,26 @@ class OpenAIConversationEntity( "Sorry, I had a problem talking to OpenAI", ) return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id + response=intent_response, conversation_id=session.conversation_id ) LOGGER.debug("Response %s", result) response = result.choices[0].message + messages.append(_message_convert(response)) - def message_convert( - message: ChatCompletionMessage, - ) -> ChatCompletionMessageParam: - """Convert from class to TypedDict.""" - tool_calls: list[ChatCompletionMessageToolCallParam] = [] - if message.tool_calls: - tool_calls = [ - ChatCompletionMessageToolCallParam( - id=tool_call.id, - function=Function( - arguments=tool_call.function.arguments, - name=tool_call.function.name, - ), - type=tool_call.type, - ) - for tool_call in message.tool_calls - ] - param = ChatCompletionAssistantMessageParam( - role=message.role, - content=message.content, - ) - if tool_calls: - param["tool_calls"] = tool_calls - return param + session.async_add_message( + conversation.ChatMessage( + role=response.role, + agent_id=user_input.agent_id, + content=response.content or "", + native=messages[-1], + ), + ) - history.messages.append(message_convert(response)) - tool_calls = response.tool_calls - - if not tool_calls or not llm_api: + if not response.tool_calls or not session.llm_api: break - for tool_call in tool_calls: + for tool_call in response.tool_calls: tool_input = llm.ToolInput( tool_name=tool_call.function.name, tool_args=json.loads(tool_call.function.arguments), @@ -324,27 +246,33 @@ class OpenAIConversationEntity( ) try: - tool_response = await llm_api.async_call_tool(tool_input) + tool_response = await session.llm_api.async_call_tool(tool_input) except (HomeAssistantError, vol.Invalid) as e: tool_response = {"error": type(e).__name__} if str(e): tool_response["error_text"] = str(e) LOGGER.debug("Tool response: %s", tool_response) - history.messages.append( + messages.append( ChatCompletionToolMessageParam( role="tool", tool_call_id=tool_call.id, content=json.dumps(tool_response), ) ) - - self.history[conversation_id] = history + session.async_add_message( + conversation.ChatMessage( + role="native", + agent_id=user_input.agent_id, + content="", + native=messages[-1], + ) + ) intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_speech(response.content or "") return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id + response=intent_response, conversation_id=session.conversation_id ) async def _async_entry_update_listener( diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index cb303f4aa65..f66794165f0 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -15,10 +15,6 @@ import voluptuous as vol from voluptuous_openapi import UNSUPPORTED, convert from homeassistant.components.climate import INTENT_GET_TEMPERATURE -from homeassistant.components.conversation import ( - ConversationTraceEventType, - async_conversation_trace_append, -) from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER from homeassistant.components.homeassistant import async_should_expose from homeassistant.components.intent import async_device_supports_timers @@ -171,6 +167,12 @@ class APIInstance: async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: """Call a LLM tool, validate args and return the response.""" + # pylint: disable=import-outside-toplevel + from homeassistant.components.conversation import ( + ConversationTraceEventType, + async_conversation_trace_append, + ) + async_conversation_trace_append( ConversationTraceEventType.TOOL_CALL, {"tool_name": tool_input.tool_name, "tool_args": tool_input.tool_args}, diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index 171014fdc4a..526e1bff151 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -44,7 +44,7 @@ dict({ 'data': dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -135,7 +135,7 @@ dict({ 'data': dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -226,7 +226,7 @@ dict({ 'data': dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -341,7 +341,7 @@ dict({ 'data': dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -446,7 +446,7 @@ dict({ 'data': dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -497,7 +497,7 @@ dict({ 'data': dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -548,7 +548,7 @@ dict({ 'data': dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 41747a50eb6..917a9b654d5 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -42,7 +42,7 @@ # name: test_audio_pipeline.4 dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -125,7 +125,7 @@ # name: test_audio_pipeline_debug.4 dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -220,7 +220,7 @@ # name: test_audio_pipeline_with_enhancements.4 dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -325,7 +325,7 @@ # name: test_audio_pipeline_with_wake_word_no_timeout.6 dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -585,7 +585,7 @@ # name: test_pipeline_empty_tts_output.2 dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -698,7 +698,7 @@ # name: test_text_only_pipeline[extra_msg0].2 dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -744,7 +744,7 @@ # name: test_text_only_pipeline[extra_msg1].2 dict({ 'intent_output': dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), diff --git a/tests/components/conversation/snapshots/test_default_agent.ambr b/tests/components/conversation/snapshots/test_default_agent.ambr index f1e220b10b2..c2b16ea2912 100644 --- a/tests/components/conversation/snapshots/test_default_agent.ambr +++ b/tests/components/conversation/snapshots/test_default_agent.ambr @@ -1,7 +1,7 @@ # serializer version: 1 # name: test_custom_sentences dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -26,7 +26,7 @@ # --- # name: test_custom_sentences.1 dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -51,7 +51,7 @@ # --- # name: test_custom_sentences_config dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -76,7 +76,7 @@ # --- # name: test_intent_alias_added_removed dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -106,7 +106,7 @@ # --- # name: test_intent_alias_added_removed.1 dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -136,7 +136,7 @@ # --- # name: test_intent_alias_added_removed.2 dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -156,7 +156,7 @@ # --- # name: test_intent_conversion_not_expose_new dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -176,7 +176,7 @@ # --- # name: test_intent_conversion_not_expose_new.1 dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -206,7 +206,7 @@ # --- # name: test_intent_entity_added_removed dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -236,7 +236,7 @@ # --- # name: test_intent_entity_added_removed.1 dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -266,7 +266,7 @@ # --- # name: test_intent_entity_added_removed.2 dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -296,7 +296,7 @@ # --- # name: test_intent_entity_added_removed.3 dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -316,7 +316,7 @@ # --- # name: test_intent_entity_exposed dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -346,7 +346,7 @@ # --- # name: test_intent_entity_fail_if_unexposed dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -366,7 +366,7 @@ # --- # name: test_intent_entity_remove_custom_name dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -386,7 +386,7 @@ # --- # name: test_intent_entity_remove_custom_name.1 dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -416,7 +416,7 @@ # --- # name: test_intent_entity_remove_custom_name.2 dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -436,7 +436,7 @@ # --- # name: test_intent_entity_renamed dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -466,7 +466,7 @@ # --- # name: test_intent_entity_renamed.1 dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), diff --git a/tests/components/conversation/snapshots/test_http.ambr b/tests/components/conversation/snapshots/test_http.ambr index 0de575790db..1102a41e6c3 100644 --- a/tests/components/conversation/snapshots/test_http.ambr +++ b/tests/components/conversation/snapshots/test_http.ambr @@ -201,7 +201,7 @@ # --- # name: test_http_api_handle_failure dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -221,7 +221,7 @@ # --- # name: test_http_api_no_match dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -241,7 +241,7 @@ # --- # name: test_http_api_unexpected_failure dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -261,7 +261,7 @@ # --- # name: test_http_processing_intent[None] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -291,7 +291,7 @@ # --- # name: test_http_processing_intent[conversation.home_assistant] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -321,7 +321,7 @@ # --- # name: test_http_processing_intent[homeassistant] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -351,7 +351,7 @@ # --- # name: test_ws_api[payload0] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -371,7 +371,7 @@ # --- # name: test_ws_api[payload1] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -391,7 +391,7 @@ # --- # name: test_ws_api[payload2] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -411,7 +411,7 @@ # --- # name: test_ws_api[payload3] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -431,7 +431,7 @@ # --- # name: test_ws_api[payload4] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -451,7 +451,7 @@ # --- # name: test_ws_api[payload5] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), diff --git a/tests/components/conversation/snapshots/test_init.ambr b/tests/components/conversation/snapshots/test_init.ambr index 0327be064d4..911c7043a6d 100644 --- a/tests/components/conversation/snapshots/test_init.ambr +++ b/tests/components/conversation/snapshots/test_init.ambr @@ -1,7 +1,7 @@ # serializer version: 1 # name: test_custom_agent dict({ - 'conversation_id': 'test-conv-id', + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -44,7 +44,7 @@ # --- # name: test_turn_on_intent[None-turn kitchen on-None] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -74,7 +74,7 @@ # --- # name: test_turn_on_intent[None-turn kitchen on-conversation.home_assistant] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -104,7 +104,7 @@ # --- # name: test_turn_on_intent[None-turn kitchen on-homeassistant] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -134,7 +134,7 @@ # --- # name: test_turn_on_intent[None-turn on kitchen-None] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -164,7 +164,7 @@ # --- # name: test_turn_on_intent[None-turn on kitchen-conversation.home_assistant] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -194,7 +194,7 @@ # --- # name: test_turn_on_intent[None-turn on kitchen-homeassistant] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -224,7 +224,7 @@ # --- # name: test_turn_on_intent[my_new_conversation-turn kitchen on-None] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -254,7 +254,7 @@ # --- # name: test_turn_on_intent[my_new_conversation-turn kitchen on-conversation.home_assistant] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -284,7 +284,7 @@ # --- # name: test_turn_on_intent[my_new_conversation-turn kitchen on-homeassistant] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -314,7 +314,7 @@ # --- # name: test_turn_on_intent[my_new_conversation-turn on kitchen-None] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -344,7 +344,7 @@ # --- # name: test_turn_on_intent[my_new_conversation-turn on kitchen-conversation.home_assistant] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), @@ -374,7 +374,7 @@ # --- # name: test_turn_on_intent[my_new_conversation-turn on kitchen-homeassistant] dict({ - 'conversation_id': None, + 'conversation_id': , 'response': dict({ 'card': dict({ }), diff --git a/tests/components/conversation/test_session.py b/tests/components/conversation/test_session.py new file mode 100644 index 00000000000..45cb517528d --- /dev/null +++ b/tests/components/conversation/test_session.py @@ -0,0 +1,171 @@ +"""Test the conversation session.""" + +from collections.abc import Generator +from datetime import timedelta +from unittest.mock import Mock, patch + +import pytest + +from homeassistant.components.conversation import ConversationInput, session +from homeassistant.core import Context, HomeAssistant +from homeassistant.util import dt as dt_util + +from tests.common import async_fire_time_changed + + +@pytest.fixture +def mock_conversation_input(hass: HomeAssistant) -> ConversationInput: + """Return a conversation input instance.""" + return ConversationInput( + text="Hello", + context=Context(), + conversation_id=None, + agent_id="mock-agent-id", + device_id=None, + language="en", + ) + + +@pytest.fixture +def mock_ulid() -> Generator[Mock]: + """Mock the ulid library.""" + with patch("homeassistant.util.ulid.ulid_now") as mock_ulid_now: + mock_ulid_now.return_value = "mock-ulid" + yield mock_ulid_now + + +@pytest.mark.parametrize( + ("start_id", "given_id"), + [ + (None, "mock-ulid"), + # This ULID is not known as a session + ("01JHXE0952TSJCFJZ869AW6HMD", "mock-ulid"), + ("not-a-ulid", "not-a-ulid"), + ], +) +async def test_conversation_id( + hass: HomeAssistant, + mock_conversation_input: ConversationInput, + mock_ulid: Mock, + start_id: str | None, + given_id: str, +) -> None: + """Test conversation ID generation.""" + mock_conversation_input.conversation_id = start_id + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + assert chat_session.conversation_id == given_id + + +async def test_cleanup( + hass: HomeAssistant, + mock_conversation_input: ConversationInput, +) -> None: + """Mock cleanup of the conversation session.""" + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + assert len(chat_session.messages) == 2 + conversation_id = chat_session.conversation_id + + # Generate session entry. + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + # Because we didn't add a message to the session in the last block, + # the conversation was not be persisted and we get a new ID + assert chat_session.conversation_id != conversation_id + conversation_id = chat_session.conversation_id + chat_session.async_add_message( + session.ChatMessage( + role="assistant", + agent_id="mock-agent-id", + content="Hey!", + ) + ) + assert len(chat_session.messages) == 3 + + # Reuse conversation ID to ensure we can chat with same session + mock_conversation_input.conversation_id = conversation_id + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + assert len(chat_session.messages) == 4 + assert chat_session.conversation_id == conversation_id + + async_fire_time_changed( + hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT + timedelta(seconds=1) + ) + + # It should be cleaned up now and we start a new conversation + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + assert chat_session.conversation_id != conversation_id + assert len(chat_session.messages) == 2 + + +async def test_message_filtering( + hass: HomeAssistant, mock_conversation_input: ConversationInput +) -> None: + """Test filtering of messages.""" + async with session.async_get_chat_session( + hass, mock_conversation_input + ) as chat_session: + messages = chat_session.async_get_messages(agent_id=None) + assert len(messages) == 2 + assert messages[0] == session.ChatMessage( + role="system", + agent_id=None, + content="", + ) + assert messages[1] == session.ChatMessage( + role="user", + agent_id=mock_conversation_input.agent_id, + content=mock_conversation_input.text, + ) + # Cannot add a second user message in a row + with pytest.raises(ValueError): + chat_session.async_add_message( + session.ChatMessage( + role="user", + agent_id="mock-agent-id", + content="Hey!", + ) + ) + + chat_session.async_add_message( + session.ChatMessage( + role="assistant", + agent_id="mock-agent-id", + content="Hey!", + native="assistant-reply-native", + ) + ) + # Different agent, will be filtered out. + chat_session.async_add_message( + session.ChatMessage( + role="native", agent_id="another-mock-agent-id", content="", native=1 + ) + ) + chat_session.async_add_message( + session.ChatMessage( + role="native", agent_id="mock-agent-id", content="", native=1 + ) + ) + + assert len(chat_session.messages) == 5 + + messages = chat_session.async_get_messages(agent_id="mock-agent-id") + assert len(messages) == 4 + + assert messages[2] == session.ChatMessage( + role="assistant", + agent_id="mock-agent-id", + content="Hey!", + native="assistant-reply-native", + ) + assert messages[3] == session.ChatMessage( + role="native", agent_id="mock-agent-id", content="", native=1 + ) diff --git a/tests/components/openai_conversation/snapshots/test_conversation.ambr b/tests/components/openai_conversation/snapshots/test_conversation.ambr index eaa3a9de64c..4ef8b8655ee 100644 --- a/tests/components/openai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/openai_conversation/snapshots/test_conversation.ambr @@ -1,7 +1,7 @@ # serializer version: 1 # name: test_unknown_hass_api dict({ - 'conversation_id': None, + 'conversation_id': 'my-conversation-id', 'response': IntentResponse( card=dict({ }), diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 774f60ed666..b89ddcd8921 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -625,7 +625,11 @@ async def test_unknown_hass_api( await hass.async_block_till_done() result = await conversation.async_converse( - hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + hass, + "hello", + "my-conversation-id", + Context(), + agent_id=mock_config_entry.entry_id, ) assert result == snapshot diff --git a/tests/syrupy.py b/tests/syrupy.py index 8812b3c3880..5b1e5faa23d 100644 --- a/tests/syrupy.py +++ b/tests/syrupy.py @@ -109,6 +109,8 @@ class HomeAssistantSnapshotSerializer(AmberDataSerializer): serializable_data = cls._serializable_issue_registry_entry(data) elif isinstance(data, dict) and "flow_id" in data and "handler" in data: serializable_data = cls._serializable_flow_result(data) + elif isinstance(data, dict) and set(data) == {"conversation_id", "response"}: + serializable_data = cls._serializable_conversation_result(data) elif isinstance(data, vol.Schema): serializable_data = voluptuous_serialize.convert(data) elif isinstance(data, ConfigEntry): @@ -200,6 +202,11 @@ class HomeAssistantSnapshotSerializer(AmberDataSerializer): """Prepare a Home Assistant flow result for serialization.""" return FlowResultSnapshot(data | {"flow_id": ANY}) + @classmethod + def _serializable_conversation_result(cls, data: dict) -> SerializableData: + """Prepare a Home Assistant conversation result for serialization.""" + return data | {"conversation_id": ANY} + @classmethod def _serializable_issue_registry_entry( cls, data: ir.IssueEntry