From d13c14eedb5c5757d62238b2f6aa29a0a3fa4f26 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 6 Jan 2025 17:01:13 -0500 Subject: [PATCH] Add support for extra_system_prompt to OpenAI (#134931) --- .../openai_conversation/conversation.py | 52 ++++++--- .../openai_conversation/test_conversation.py | 101 ++++++++++++++++++ 2 files changed, 138 insertions(+), 15 deletions(-) diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index 9c73766c8d4..b3f31ae9b47 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -1,6 +1,7 @@ """Conversation support for OpenAI.""" from collections.abc import Callable +from dataclasses import dataclass, field import json from typing import Any, Literal @@ -73,6 +74,14 @@ def _format_tool( return ChatCompletionToolParam(type="function", function=tool_spec) +@dataclass +class ChatHistory: + """Class holding the chat history.""" + + extra_system_prompt: str | None = None + messages: list[ChatCompletionMessageParam] = field(default_factory=list) + + class OpenAIConversationEntity( conversation.ConversationEntity, conversation.AbstractConversationAgent ): @@ -84,7 +93,7 @@ class OpenAIConversationEntity( def __init__(self, entry: OpenAIConfigEntry) -> None: """Initialize the agent.""" self.entry = entry - self.history: dict[str, list[ChatCompletionMessageParam]] = {} + self.history: dict[str, ChatHistory] = {} self._attr_unique_id = entry.entry_id self._attr_device_info = dr.DeviceInfo( identifiers={(DOMAIN, entry.entry_id)}, @@ -157,13 +166,14 @@ class OpenAIConversationEntity( _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools ] + history: ChatHistory | None = None + if user_input.conversation_id is None: conversation_id = ulid.ulid_now() - messages = [] elif user_input.conversation_id in self.history: conversation_id = 
user_input.conversation_id - messages = self.history[conversation_id] + history = self.history.get(conversation_id) else: # Conversation IDs are ULIDs. We generate a new one if not provided. @@ -176,7 +186,8 @@ class OpenAIConversationEntity( except ValueError: conversation_id = user_input.conversation_id - messages = [] + if history is None: + history = ChatHistory(user_input.extra_system_prompt) if ( user_input.context @@ -217,20 +228,31 @@ class OpenAIConversationEntity( if llm_api: prompt_parts.append(llm_api.api_prompt) + extra_system_prompt = ( + # Take new system prompt if one was given + user_input.extra_system_prompt or history.extra_system_prompt + ) + + if extra_system_prompt: + prompt_parts.append(extra_system_prompt) + prompt = "\n".join(prompt_parts) # Create a copy of the variable because we attach it to the trace - messages = [ - ChatCompletionSystemMessageParam(role="system", content=prompt), - *messages[1:], - ChatCompletionUserMessageParam(role="user", content=user_input.text), - ] + history = ChatHistory( + extra_system_prompt, + [ + ChatCompletionSystemMessageParam(role="system", content=prompt), + *history.messages[1:], + ChatCompletionUserMessageParam(role="user", content=user_input.text), + ], + ) - LOGGER.debug("Prompt: %s", messages) + LOGGER.debug("Prompt: %s", history.messages) LOGGER.debug("Tools: %s", tools) trace.async_conversation_trace_append( trace.ConversationTraceEventType.AGENT_DETAIL, - {"messages": messages, "tools": llm_api.tools if llm_api else None}, + {"messages": history.messages, "tools": llm_api.tools if llm_api else None}, ) client = self.entry.runtime_data @@ -240,7 +262,7 @@ class OpenAIConversationEntity( try: result = await client.chat.completions.create( model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), - messages=messages, + messages=history.messages, tools=tools or NOT_GIVEN, max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P), @@ -286,7 
+308,7 @@ class OpenAIConversationEntity( param["tool_calls"] = tool_calls return param - messages.append(message_convert(response)) + history.messages.append(message_convert(response)) tool_calls = response.tool_calls if not tool_calls or not llm_api: @@ -309,7 +331,7 @@ class OpenAIConversationEntity( tool_response["error_text"] = str(e) LOGGER.debug("Tool response: %s", tool_response) - messages.append( + history.messages.append( ChatCompletionToolMessageParam( role="tool", tool_call_id=tool_call.id, @@ -317,7 +339,7 @@ class OpenAIConversationEntity( ) ) - self.history[conversation_id] = messages + self.history[conversation_id] = history intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_speech(response.content or "") diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index e0665bc449f..eed396786e2 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -149,6 +149,107 @@ async def test_template_variables( ) +async def test_extra_system_prompt( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> None: + """Test that the extra system prompt is taken into account.""" + extra_system_prompt = "Garage door cover.garage_door has been left open for 30 minutes. We asked the user if they want to close it." + extra_system_prompt2 = ( + "User person.paulus came home. Asked him what he wants to do." 
+ ) + + with ( + patch( + "openai.resources.models.AsyncModels.list", + ), + patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + ) as mock_create, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + result = await conversation.async_converse( + hass, + "hello", + None, + Context(), + agent_id=mock_config_entry.entry_id, + extra_system_prompt=extra_system_prompt, + ) + + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith( + extra_system_prompt + ) + + conversation_id = result.conversation_id + + # Verify that follow-up conversations with no system prompt take previous one + with patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + ) as mock_create: + result = await conversation.async_converse( + hass, + "hello", + conversation_id, + Context(), + agent_id=mock_config_entry.entry_id, + extra_system_prompt=None, + ) + + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith( + extra_system_prompt + ) + + # Verify that we take new system prompts + with patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + ) as mock_create: + result = await conversation.async_converse( + hass, + "hello", + conversation_id, + Context(), + agent_id=mock_config_entry.entry_id, + extra_system_prompt=extra_system_prompt2, + ) + + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith( + extra_system_prompt2 + ) + + # Verify that follow-up conversations with no system prompt take previous one + with patch( + "openai.resources.chat.completions.AsyncCompletions.create", 
+ new_callable=AsyncMock, + ) as mock_create: + result = await conversation.async_converse( + hass, + "hello", + conversation_id, + Context(), + agent_id=mock_config_entry.entry_id, + ) + + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith( + extra_system_prompt2 + ) + + async def test_conversation_agent( hass: HomeAssistant, mock_config_entry: MockConfigEntry,