Add support for extra_system_prompt to OpenAI (#134931)

This commit is contained in:
Paulus Schoutsen 2025-01-06 17:01:13 -05:00 committed by GitHub
parent 9532e98166
commit d13c14eedb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 138 additions and 15 deletions

View File

@ -1,6 +1,7 @@
"""Conversation support for OpenAI.""" """Conversation support for OpenAI."""
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass, field
import json import json
from typing import Any, Literal from typing import Any, Literal
@ -73,6 +74,14 @@ def _format_tool(
return ChatCompletionToolParam(type="function", function=tool_spec) return ChatCompletionToolParam(type="function", function=tool_spec)
@dataclass
class ChatHistory:
    """Class holding the chat history."""

    # Extra system prompt supplied with the conversation, if any; when a
    # follow-up turn arrives without one, this stored value is reused
    # (see the `user_input.extra_system_prompt or history.extra_system_prompt`
    # fallback in the conversation handler).
    extra_system_prompt: str | None = None
    # Ordered OpenAI chat messages for this conversation; the system prompt
    # is kept as the first entry and rebuilt on every turn.
    messages: list[ChatCompletionMessageParam] = field(default_factory=list)
class OpenAIConversationEntity( class OpenAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent conversation.ConversationEntity, conversation.AbstractConversationAgent
): ):
@ -84,7 +93,7 @@ class OpenAIConversationEntity(
def __init__(self, entry: OpenAIConfigEntry) -> None: def __init__(self, entry: OpenAIConfigEntry) -> None:
"""Initialize the agent.""" """Initialize the agent."""
self.entry = entry self.entry = entry
self.history: dict[str, list[ChatCompletionMessageParam]] = {} self.history: dict[str, ChatHistory] = {}
self._attr_unique_id = entry.entry_id self._attr_unique_id = entry.entry_id
self._attr_device_info = dr.DeviceInfo( self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)}, identifiers={(DOMAIN, entry.entry_id)},
@ -157,13 +166,14 @@ class OpenAIConversationEntity(
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
] ]
history: ChatHistory | None = None
if user_input.conversation_id is None: if user_input.conversation_id is None:
conversation_id = ulid.ulid_now() conversation_id = ulid.ulid_now()
messages = []
elif user_input.conversation_id in self.history: elif user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id conversation_id = user_input.conversation_id
messages = self.history[conversation_id] history = self.history.get(conversation_id)
else: else:
# Conversation IDs are ULIDs. We generate a new one if not provided. # Conversation IDs are ULIDs. We generate a new one if not provided.
@ -176,7 +186,8 @@ class OpenAIConversationEntity(
except ValueError: except ValueError:
conversation_id = user_input.conversation_id conversation_id = user_input.conversation_id
messages = [] if history is None:
history = ChatHistory(user_input.extra_system_prompt)
if ( if (
user_input.context user_input.context
@ -217,20 +228,31 @@ class OpenAIConversationEntity(
if llm_api: if llm_api:
prompt_parts.append(llm_api.api_prompt) prompt_parts.append(llm_api.api_prompt)
extra_system_prompt = (
# Take new system prompt if one was given
user_input.extra_system_prompt or history.extra_system_prompt
)
if extra_system_prompt:
prompt_parts.append(extra_system_prompt)
prompt = "\n".join(prompt_parts) prompt = "\n".join(prompt_parts)
# Create a copy of the variable because we attach it to the trace # Create a copy of the variable because we attach it to the trace
messages = [ history = ChatHistory(
ChatCompletionSystemMessageParam(role="system", content=prompt), extra_system_prompt,
*messages[1:], [
ChatCompletionUserMessageParam(role="user", content=user_input.text), ChatCompletionSystemMessageParam(role="system", content=prompt),
] *history.messages[1:],
ChatCompletionUserMessageParam(role="user", content=user_input.text),
],
)
LOGGER.debug("Prompt: %s", messages) LOGGER.debug("Prompt: %s", history.messages)
LOGGER.debug("Tools: %s", tools) LOGGER.debug("Tools: %s", tools)
trace.async_conversation_trace_append( trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, trace.ConversationTraceEventType.AGENT_DETAIL,
{"messages": messages, "tools": llm_api.tools if llm_api else None}, {"messages": history.messages, "tools": llm_api.tools if llm_api else None},
) )
client = self.entry.runtime_data client = self.entry.runtime_data
@ -240,7 +262,7 @@ class OpenAIConversationEntity(
try: try:
result = await client.chat.completions.create( result = await client.chat.completions.create(
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=messages, messages=history.messages,
tools=tools or NOT_GIVEN, tools=tools or NOT_GIVEN,
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P), top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
@ -286,7 +308,7 @@ class OpenAIConversationEntity(
param["tool_calls"] = tool_calls param["tool_calls"] = tool_calls
return param return param
messages.append(message_convert(response)) history.messages.append(message_convert(response))
tool_calls = response.tool_calls tool_calls = response.tool_calls
if not tool_calls or not llm_api: if not tool_calls or not llm_api:
@ -309,7 +331,7 @@ class OpenAIConversationEntity(
tool_response["error_text"] = str(e) tool_response["error_text"] = str(e)
LOGGER.debug("Tool response: %s", tool_response) LOGGER.debug("Tool response: %s", tool_response)
messages.append( history.messages.append(
ChatCompletionToolMessageParam( ChatCompletionToolMessageParam(
role="tool", role="tool",
tool_call_id=tool_call.id, tool_call_id=tool_call.id,
@ -317,7 +339,7 @@ class OpenAIConversationEntity(
) )
) )
self.history[conversation_id] = messages self.history[conversation_id] = history
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response.content or "") intent_response.async_set_speech(response.content or "")

View File

@ -149,6 +149,107 @@ async def test_template_variables(
) )
async def test_extra_systen_prompt(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
    """Test that an extra system prompt is applied and remembered across turns."""
    # NOTE(review): the function name contains a typo ("systen" -> "system");
    # kept as-is so externally referenced test selections keep working.
    extra_system_prompt = "Garage door cover.garage_door has been left open for 30 minutes. We asked the user if they want to close it."
    extra_system_prompt2 = (
        "User person.paulus came home. Asked him what he wants to do."
    )

    # First turn: a fresh conversation with an extra system prompt.
    with (
        patch(
            "openai.resources.models.AsyncModels.list",
        ),
        patch(
            "openai.resources.chat.completions.AsyncCompletions.create",
            new_callable=AsyncMock,
        ) as mock_create,
    ):
        await hass.config_entries.async_setup(mock_config_entry.entry_id)
        await hass.async_block_till_done()
        result = await conversation.async_converse(
            hass,
            "hello",
            None,
            Context(),
            agent_id=mock_config_entry.entry_id,
            extra_system_prompt=extra_system_prompt,
        )

    assert (
        result.response.response_type == intent.IntentResponseType.ACTION_DONE
    ), result
    # The extra prompt must be appended to the system message content.
    assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
        extra_system_prompt
    )

    conversation_id = result.conversation_id

    # Verify that follow-up conversations with no system prompt take previous one
    with patch(
        "openai.resources.chat.completions.AsyncCompletions.create",
        new_callable=AsyncMock,
    ) as mock_create:
        result = await conversation.async_converse(
            hass,
            "hello",
            conversation_id,
            Context(),
            agent_id=mock_config_entry.entry_id,
            extra_system_prompt=None,
        )

    assert (
        result.response.response_type == intent.IntentResponseType.ACTION_DONE
    ), result
    assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
        extra_system_prompt
    )

    # Verify that we take new system prompts
    with patch(
        "openai.resources.chat.completions.AsyncCompletions.create",
        new_callable=AsyncMock,
    ) as mock_create:
        result = await conversation.async_converse(
            hass,
            "hello",
            conversation_id,
            Context(),
            agent_id=mock_config_entry.entry_id,
            extra_system_prompt=extra_system_prompt2,
        )

    assert (
        result.response.response_type == intent.IntentResponseType.ACTION_DONE
    ), result
    assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
        extra_system_prompt2
    )

    # Verify that follow-up conversations with no system prompt take previous one
    with patch(
        "openai.resources.chat.completions.AsyncCompletions.create",
        new_callable=AsyncMock,
    ) as mock_create:
        result = await conversation.async_converse(
            hass,
            "hello",
            conversation_id,
            Context(),
            agent_id=mock_config_entry.entry_id,
        )

    assert (
        result.response.response_type == intent.IntentResponseType.ACTION_DONE
    ), result
    assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
        extra_system_prompt2
    )
async def test_conversation_agent( async def test_conversation_agent(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,