From 299c0de968345ccd0bb70cf5eb8e70e835ee2f32 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 3 Jun 2024 16:27:05 -0400 Subject: [PATCH] Update OpenAI prompt on each interaction (#118747) --- .../openai_conversation/conversation.py | 96 +++++++++---------- .../openai_conversation/test_conversation.py | 50 +++++++++- 2 files changed, 93 insertions(+), 53 deletions(-) diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index 306e4134b9e..d5e566678f1 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -146,58 +146,58 @@ class OpenAIConversationEntity( messages = self.history[conversation_id] else: conversation_id = ulid.ulid_now() + messages = [] - if ( - user_input.context - and user_input.context.user_id - and ( - user := await self.hass.auth.async_get_user( - user_input.context.user_id - ) + if ( + user_input.context + and user_input.context.user_id + and ( + user := await self.hass.auth.async_get_user(user_input.context.user_id) + ) + ): + user_name = user.name + + try: + if llm_api: + api_prompt = llm_api.api_prompt + else: + api_prompt = llm.async_render_no_api_prompt(self.hass) + + prompt = "\n".join( + ( + template.Template( + llm.BASE_PROMPT + + options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT), + self.hass, + ).async_render( + { + "ha_name": self.hass.config.location_name, + "user_name": user_name, + "llm_context": llm_context, + }, + parse_result=False, + ), + api_prompt, ) - ): - user_name = user.name + ) - try: - if llm_api: - api_prompt = llm_api.api_prompt - else: - api_prompt = llm.async_render_no_api_prompt(self.hass) + except TemplateError as err: + LOGGER.error("Error rendering prompt: %s", err) + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem with my template: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) - prompt = "\n".join( - ( - template.Template( - llm.BASE_PROMPT - + options.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT), - self.hass, - ).async_render( - { - "ha_name": self.hass.config.location_name, - "user_name": user_name, - "llm_context": llm_context, - }, - parse_result=False, - ), - api_prompt, - ) - ) - - except TemplateError as err: - LOGGER.error("Error rendering prompt: %s", err) - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem with my template: {err}", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) - - messages = [ChatCompletionSystemMessageParam(role="system", content=prompt)] - - messages.append( - ChatCompletionUserMessageParam(role="user", content=user_input.text) - ) + # Create a copy of the variable because we attach it to the trace + messages = [ + ChatCompletionSystemMessageParam(role="system", content=prompt), + *messages[1:], + ChatCompletionUserMessageParam(role="user", content=user_input.text), + ] LOGGER.debug("Prompt: %s", messages) trace.async_conversation_trace_append( diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 05d62ffd61b..002b2df186b 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, Mock, patch +from freezegun import freeze_time from httpx import Response from openai import RateLimitError from openai.types.chat.chat_completion import ChatCompletion, Choice @@ -214,11 +215,14 @@ async def test_function_call( ), ) - with patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - side_effect=completion_result, - ) as mock_create: + with ( + patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + side_effect=completion_result, + ) as mock_create, + freeze_time("2024-06-03 23:00:00"), + ): result = await conversation.async_converse( hass, "Please call the test function", @@ -227,6 +231,11 @@ async def test_function_call( agent_id=agent_id, ) + assert ( + "Today's date is 2024-06-03." + in mock_create.mock_calls[1][2]["messages"][0]["content"] + ) + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert mock_create.mock_calls[1][2]["messages"][3] == { "role": "tool", @@ -262,6 +271,37 @@ async def test_function_call( # AGENT_DETAIL event contains the raw prompt passed to the model detail_event = trace_events[1] assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"] + assert ( + "Today's date is 2024-06-03." + in trace_events[1]["data"]["messages"][0]["content"] + ) + + # Call it again, make sure we have updated prompt + with ( + patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + side_effect=completion_result, + ) as mock_create, + freeze_time("2024-06-04 23:00:00"), + ): + result = await conversation.async_converse( + hass, + "Please call the test function", + None, + context, + agent_id=agent_id, + ) + + assert ( + "Today's date is 2024-06-04." + in mock_create.mock_calls[1][2]["messages"][0]["content"] + ) + # Test old assert message not updated + assert ( + "Today's date is 2024-06-03." + in trace_events[1]["data"]["messages"][0]["content"] + ) @patch(