mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
Update OpenAI prompt on each interaction (#118747)
This commit is contained in:
parent
8ea3a6843a
commit
299c0de968
@ -146,58 +146,58 @@ class OpenAIConversationEntity(
|
||||
messages = self.history[conversation_id]
|
||||
else:
|
||||
conversation_id = ulid.ulid_now()
|
||||
messages = []
|
||||
|
||||
if (
|
||||
user_input.context
|
||||
and user_input.context.user_id
|
||||
and (
|
||||
user := await self.hass.auth.async_get_user(
|
||||
user_input.context.user_id
|
||||
)
|
||||
if (
|
||||
user_input.context
|
||||
and user_input.context.user_id
|
||||
and (
|
||||
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
||||
)
|
||||
):
|
||||
user_name = user.name
|
||||
|
||||
try:
|
||||
if llm_api:
|
||||
api_prompt = llm_api.api_prompt
|
||||
else:
|
||||
api_prompt = llm.async_render_no_api_prompt(self.hass)
|
||||
|
||||
prompt = "\n".join(
|
||||
(
|
||||
template.Template(
|
||||
llm.BASE_PROMPT
|
||||
+ options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
||||
self.hass,
|
||||
).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
"user_name": user_name,
|
||||
"llm_context": llm_context,
|
||||
},
|
||||
parse_result=False,
|
||||
),
|
||||
api_prompt,
|
||||
)
|
||||
):
|
||||
user_name = user.name
|
||||
)
|
||||
|
||||
try:
|
||||
if llm_api:
|
||||
api_prompt = llm_api.api_prompt
|
||||
else:
|
||||
api_prompt = llm.async_render_no_api_prompt(self.hass)
|
||||
except TemplateError as err:
|
||||
LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem with my template: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
prompt = "\n".join(
|
||||
(
|
||||
template.Template(
|
||||
llm.BASE_PROMPT
|
||||
+ options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
||||
self.hass,
|
||||
).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
"user_name": user_name,
|
||||
"llm_context": llm_context,
|
||||
},
|
||||
parse_result=False,
|
||||
),
|
||||
api_prompt,
|
||||
)
|
||||
)
|
||||
|
||||
except TemplateError as err:
|
||||
LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem with my template: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
messages = [ChatCompletionSystemMessageParam(role="system", content=prompt)]
|
||||
|
||||
messages.append(
|
||||
ChatCompletionUserMessageParam(role="user", content=user_input.text)
|
||||
)
|
||||
# Create a copy of the variable because we attach it to the trace
|
||||
messages = [
|
||||
ChatCompletionSystemMessageParam(role="system", content=prompt),
|
||||
*messages[1:],
|
||||
ChatCompletionUserMessageParam(role="user", content=user_input.text),
|
||||
]
|
||||
|
||||
LOGGER.debug("Prompt: %s", messages)
|
||||
trace.async_conversation_trace_append(
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
from httpx import Response
|
||||
from openai import RateLimitError
|
||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
@ -214,11 +215,14 @@ async def test_function_call(
|
||||
),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=completion_result,
|
||||
) as mock_create:
|
||||
with (
|
||||
patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=completion_result,
|
||||
) as mock_create,
|
||||
freeze_time("2024-06-03 23:00:00"),
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
@ -227,6 +231,11 @@ async def test_function_call(
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert (
|
||||
"Today's date is 2024-06-03."
|
||||
in mock_create.mock_calls[1][2]["messages"][0]["content"]
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert mock_create.mock_calls[1][2]["messages"][3] == {
|
||||
"role": "tool",
|
||||
@ -262,6 +271,37 @@ async def test_function_call(
|
||||
# AGENT_DETAIL event contains the raw prompt passed to the model
|
||||
detail_event = trace_events[1]
|
||||
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
|
||||
assert (
|
||||
"Today's date is 2024-06-03."
|
||||
in trace_events[1]["data"]["messages"][0]["content"]
|
||||
)
|
||||
|
||||
# Call it again, make sure we have updated prompt
|
||||
with (
|
||||
patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=completion_result,
|
||||
) as mock_create,
|
||||
freeze_time("2024-06-04 23:00:00"),
|
||||
):
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
None,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert (
|
||||
"Today's date is 2024-06-04."
|
||||
in mock_create.mock_calls[1][2]["messages"][0]["content"]
|
||||
)
|
||||
# Test old assert message not updated
|
||||
assert (
|
||||
"Today's date is 2024-06-03."
|
||||
in trace_events[1]["data"]["messages"][0]["content"]
|
||||
)
|
||||
|
||||
|
||||
@patch(
|
||||
|
Loading…
x
Reference in New Issue
Block a user