mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 03:07:37 +00:00
Add support for extra_system_prompt to OpenAI (#134931)
This commit is contained in:
parent
9532e98166
commit
d13c14eedb
@ -1,6 +1,7 @@
|
|||||||
"""Conversation support for OpenAI."""
|
"""Conversation support for OpenAI."""
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
import json
|
import json
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@ -73,6 +74,14 @@ def _format_tool(
|
|||||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatHistory:
|
||||||
|
"""Class holding the chat history."""
|
||||||
|
|
||||||
|
extra_system_prompt: str | None = None
|
||||||
|
messages: list[ChatCompletionMessageParam] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIConversationEntity(
|
class OpenAIConversationEntity(
|
||||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||||
):
|
):
|
||||||
@ -84,7 +93,7 @@ class OpenAIConversationEntity(
|
|||||||
def __init__(self, entry: OpenAIConfigEntry) -> None:
|
def __init__(self, entry: OpenAIConfigEntry) -> None:
|
||||||
"""Initialize the agent."""
|
"""Initialize the agent."""
|
||||||
self.entry = entry
|
self.entry = entry
|
||||||
self.history: dict[str, list[ChatCompletionMessageParam]] = {}
|
self.history: dict[str, ChatHistory] = {}
|
||||||
self._attr_unique_id = entry.entry_id
|
self._attr_unique_id = entry.entry_id
|
||||||
self._attr_device_info = dr.DeviceInfo(
|
self._attr_device_info = dr.DeviceInfo(
|
||||||
identifiers={(DOMAIN, entry.entry_id)},
|
identifiers={(DOMAIN, entry.entry_id)},
|
||||||
@ -157,13 +166,14 @@ class OpenAIConversationEntity(
|
|||||||
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
|
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
|
||||||
]
|
]
|
||||||
|
|
||||||
|
history: ChatHistory | None = None
|
||||||
|
|
||||||
if user_input.conversation_id is None:
|
if user_input.conversation_id is None:
|
||||||
conversation_id = ulid.ulid_now()
|
conversation_id = ulid.ulid_now()
|
||||||
messages = []
|
|
||||||
|
|
||||||
elif user_input.conversation_id in self.history:
|
elif user_input.conversation_id in self.history:
|
||||||
conversation_id = user_input.conversation_id
|
conversation_id = user_input.conversation_id
|
||||||
messages = self.history[conversation_id]
|
history = self.history.get(conversation_id)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Conversation IDs are ULIDs. We generate a new one if not provided.
|
# Conversation IDs are ULIDs. We generate a new one if not provided.
|
||||||
@ -176,7 +186,8 @@ class OpenAIConversationEntity(
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
conversation_id = user_input.conversation_id
|
conversation_id = user_input.conversation_id
|
||||||
|
|
||||||
messages = []
|
if history is None:
|
||||||
|
history = ChatHistory(user_input.extra_system_prompt)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
user_input.context
|
user_input.context
|
||||||
@ -217,20 +228,31 @@ class OpenAIConversationEntity(
|
|||||||
if llm_api:
|
if llm_api:
|
||||||
prompt_parts.append(llm_api.api_prompt)
|
prompt_parts.append(llm_api.api_prompt)
|
||||||
|
|
||||||
|
extra_system_prompt = (
|
||||||
|
# Take new system prompt if one was given
|
||||||
|
user_input.extra_system_prompt or history.extra_system_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
if extra_system_prompt:
|
||||||
|
prompt_parts.append(extra_system_prompt)
|
||||||
|
|
||||||
prompt = "\n".join(prompt_parts)
|
prompt = "\n".join(prompt_parts)
|
||||||
|
|
||||||
# Create a copy of the variable because we attach it to the trace
|
# Create a copy of the variable because we attach it to the trace
|
||||||
messages = [
|
history = ChatHistory(
|
||||||
ChatCompletionSystemMessageParam(role="system", content=prompt),
|
extra_system_prompt,
|
||||||
*messages[1:],
|
[
|
||||||
ChatCompletionUserMessageParam(role="user", content=user_input.text),
|
ChatCompletionSystemMessageParam(role="system", content=prompt),
|
||||||
]
|
*history.messages[1:],
|
||||||
|
ChatCompletionUserMessageParam(role="user", content=user_input.text),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
LOGGER.debug("Prompt: %s", messages)
|
LOGGER.debug("Prompt: %s", history.messages)
|
||||||
LOGGER.debug("Tools: %s", tools)
|
LOGGER.debug("Tools: %s", tools)
|
||||||
trace.async_conversation_trace_append(
|
trace.async_conversation_trace_append(
|
||||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||||
{"messages": messages, "tools": llm_api.tools if llm_api else None},
|
{"messages": history.messages, "tools": llm_api.tools if llm_api else None},
|
||||||
)
|
)
|
||||||
|
|
||||||
client = self.entry.runtime_data
|
client = self.entry.runtime_data
|
||||||
@ -240,7 +262,7 @@ class OpenAIConversationEntity(
|
|||||||
try:
|
try:
|
||||||
result = await client.chat.completions.create(
|
result = await client.chat.completions.create(
|
||||||
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||||
messages=messages,
|
messages=history.messages,
|
||||||
tools=tools or NOT_GIVEN,
|
tools=tools or NOT_GIVEN,
|
||||||
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||||
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||||
@ -286,7 +308,7 @@ class OpenAIConversationEntity(
|
|||||||
param["tool_calls"] = tool_calls
|
param["tool_calls"] = tool_calls
|
||||||
return param
|
return param
|
||||||
|
|
||||||
messages.append(message_convert(response))
|
history.messages.append(message_convert(response))
|
||||||
tool_calls = response.tool_calls
|
tool_calls = response.tool_calls
|
||||||
|
|
||||||
if not tool_calls or not llm_api:
|
if not tool_calls or not llm_api:
|
||||||
@ -309,7 +331,7 @@ class OpenAIConversationEntity(
|
|||||||
tool_response["error_text"] = str(e)
|
tool_response["error_text"] = str(e)
|
||||||
|
|
||||||
LOGGER.debug("Tool response: %s", tool_response)
|
LOGGER.debug("Tool response: %s", tool_response)
|
||||||
messages.append(
|
history.messages.append(
|
||||||
ChatCompletionToolMessageParam(
|
ChatCompletionToolMessageParam(
|
||||||
role="tool",
|
role="tool",
|
||||||
tool_call_id=tool_call.id,
|
tool_call_id=tool_call.id,
|
||||||
@ -317,7 +339,7 @@ class OpenAIConversationEntity(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.history[conversation_id] = messages
|
self.history[conversation_id] = history
|
||||||
|
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
intent_response.async_set_speech(response.content or "")
|
intent_response.async_set_speech(response.content or "")
|
||||||
|
@ -149,6 +149,107 @@ async def test_template_variables(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_extra_systen_prompt(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||||
|
) -> None:
|
||||||
|
"""Test that template variables work."""
|
||||||
|
extra_system_prompt = "Garage door cover.garage_door has been left open for 30 minutes. We asked the user if they want to close it."
|
||||||
|
extra_system_prompt2 = (
|
||||||
|
"User person.paulus came home. Asked him what he wants to do."
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"openai.resources.models.AsyncModels.list",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_create,
|
||||||
|
):
|
||||||
|
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"hello",
|
||||||
|
None,
|
||||||
|
Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
extra_system_prompt=extra_system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
), result
|
||||||
|
assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
|
||||||
|
extra_system_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_id = result.conversation_id
|
||||||
|
|
||||||
|
# Verify that follow-up conversations with no system prompt take previous one
|
||||||
|
with patch(
|
||||||
|
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_create:
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"hello",
|
||||||
|
conversation_id,
|
||||||
|
Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
extra_system_prompt=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
), result
|
||||||
|
assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
|
||||||
|
extra_system_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that we take new system prompts
|
||||||
|
with patch(
|
||||||
|
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_create:
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"hello",
|
||||||
|
conversation_id,
|
||||||
|
Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
extra_system_prompt=extra_system_prompt2,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
), result
|
||||||
|
assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
|
||||||
|
extra_system_prompt2
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that follow-up conversations with no system prompt take previous one
|
||||||
|
with patch(
|
||||||
|
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_create:
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"hello",
|
||||||
|
conversation_id,
|
||||||
|
Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
), result
|
||||||
|
assert mock_create.mock_calls[0][2]["messages"][0]["content"].endswith(
|
||||||
|
extra_system_prompt2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_conversation_agent(
|
async def test_conversation_agent(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user