From 46da43d09daef72192b167214d50174276815f2c Mon Sep 17 00:00:00 2001 From: Denis Shulyaka Date: Sat, 1 Jun 2024 03:28:23 +0800 Subject: [PATCH] Add OpenAI Conversation system prompt `user_name` and `llm_context` variables (#118512) * OpenAI Conversation: Add variables to the system prompt * User name and llm_context * test for user name * test for user id --------- Co-authored-by: Paulus Schoutsen --- .../openai_conversation/conversation.py | 32 ++++++++--- .../openai_conversation/test_conversation.py | 53 ++++++++++++++++++- 2 files changed, 75 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index 6da56d3f9a0..7cf4d18cce5 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -113,20 +113,22 @@ class OpenAIConversationEntity( intent_response = intent.IntentResponse(language=user_input.language) llm_api: llm.APIInstance | None = None tools: list[ChatCompletionToolParam] | None = None + user_name: str | None = None + llm_context = llm.LLMContext( + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=conversation.DOMAIN, + device_id=user_input.device_id, + ) if options.get(CONF_LLM_HASS_API): try: llm_api = await llm.async_get_api( self.hass, options[CONF_LLM_HASS_API], - llm.LLMContext( - platform=DOMAIN, - context=user_input.context, - user_prompt=user_input.text, - language=user_input.language, - assistant=conversation.DOMAIN, - device_id=user_input.device_id, - ), + llm_context, ) except HomeAssistantError as err: LOGGER.error("Error getting LLM API: %s", err) @@ -144,6 +146,18 @@ class OpenAIConversationEntity( messages = self.history[conversation_id] else: conversation_id = ulid.ulid_now() + + if ( + user_input.context + and user_input.context.user_id + and ( + user := await self.hass.auth.async_get_user( + user_input.context.user_id + ) + ) + ): + user_name = user.name + try: if llm_api: api_prompt = llm_api.api_prompt @@ -158,6 +172,8 @@ class OpenAIConversationEntity( ).async_render( { "ha_name": self.hass.config.location_name, + "user_name": user_name, + "llm_context": llm_context, }, parse_result=False, ), diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 10829db7575..05d62ffd61b 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -1,6 +1,6 @@ """Tests for the OpenAI integration.""" -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch from httpx import Response from openai import RateLimitError @@ -73,6 +73,53 @@ async def test_template_error( assert result.response.error_code == "unknown", result +async def test_template_variables( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> None: + """Test that template variables work.""" + context = Context(user_id="12345") + mock_user = Mock() + mock_user.id = "12345" + mock_user.name = "Test User" + + hass.config_entries.async_update_entry( + mock_config_entry, + options={ + "prompt": ( + "The user name is {{ user_name }}. " + "The user id is {{ llm_context.context.user_id }}." + ), + }, + ) + with ( + patch( + "openai.resources.models.AsyncModels.list", + ), + patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + ) as mock_create, + patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user), + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + result = await conversation.async_converse( + hass, "hello", None, context, agent_id=mock_config_entry.entry_id + ) + + assert ( + result.response.response_type == intent.IntentResponseType.ACTION_DONE + ), result + assert ( + "The user name is Test User." + in mock_create.mock_calls[0][2]["messages"][0]["content"] + ) + assert ( + "The user id is 12345." + in mock_create.mock_calls[0][2]["messages"][0]["content"] + ) + + async def test_conversation_agent( hass: HomeAssistant, mock_config_entry: MockConfigEntry, @@ -382,7 +429,9 @@ async def test_assist_api_tools_conversion( ), ), ) as mock_create: - await conversation.async_converse(hass, "hello", None, None, agent_id=agent_id) + await conversation.async_converse( + hass, "hello", None, Context(), agent_id=agent_id + ) tools = mock_create.mock_calls[0][2]["tools"] assert tools