From c097a05ed448275927c8bfdb0234228d19630068 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 28 May 2024 22:43:22 -0400 Subject: [PATCH] Tweak Assist LLM API prompt (#118343) --- homeassistant/helpers/llm.py | 14 +++++--------- tests/helpers/test_llm.py | 20 +++----------------- 2 files changed, 8 insertions(+), 26 deletions(-) diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index 2f808321c13..ae6cbbe672f 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -227,12 +227,13 @@ class AssistAPI(API): return APIInstance( api=self, - api_prompt=await self._async_get_api_prompt(tool_context, exposed_entities), + api_prompt=self._async_get_api_prompt(tool_context, exposed_entities), tool_context=tool_context, tools=self._async_get_tools(tool_context, exposed_entities), ) - async def _async_get_api_prompt( + @callback + def _async_get_api_prompt( self, tool_context: ToolContext, exposed_entities: dict | None ) -> str: """Return the prompt for the API.""" @@ -269,15 +270,10 @@ class AssistAPI(API): prompt.append(f"You are in area {area.name} {extra}") else: prompt.append( - "Reject all generic commands like 'turn on the lights' because we " - "don't know in what area this conversation is happening." + "When a user asks to turn on all devices of a specific type, " + "ask user to specify an area." ) - if tool_context.context and tool_context.context.user_id: - user = await self.hass.auth.async_get_user(tool_context.context.user_id) - if user: - prompt.append(f"The user name is {user.name}.") - if not tool_context.device_id or not async_device_supports_timers( self.hass, tool_context.device_id ): diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index c71d11da8a2..4aeb0cd93b7 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -1,6 +1,6 @@ """Tests for the llm helpers.""" -from unittest.mock import Mock, patch +from unittest.mock import patch import pytest import voluptuous as vol @@ -430,8 +430,8 @@ async def test_assist_api_prompt( no_timer_prompt = "This device does not support timers." area_prompt = ( - "Reject all generic commands like 'turn on the lights' because we don't know in what area " - "this conversation is happening." + "When a user asks to turn on all devices of a specific type, " + "ask user to specify an area." ) api = await llm.async_get_api(hass, "assist", tool_context) assert api.api_prompt == ( @@ -478,19 +478,5 @@ async def test_assist_api_prompt( assert api.api_prompt == ( f"""{first_part_prompt} {area_prompt} -{exposed_entities_prompt}""" - ) - - # Add user - context.user_id = "12345" - mock_user = Mock() - mock_user.id = "12345" - mock_user.name = "Test User" - with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user): - api = await llm.async_get_api(hass, "assist", tool_context) - assert api.api_prompt == ( - f"""{first_part_prompt} -{area_prompt} -The user name is Test User. {exposed_entities_prompt}""" )