Tweak Assist LLM API prompt (#118343)

This commit is contained in:
Paulus Schoutsen 2024-05-28 22:43:22 -04:00 committed by GitHub
parent d223e1f2ac
commit c097a05ed4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 26 deletions

View File

@ -227,12 +227,13 @@ class AssistAPI(API):
return APIInstance( return APIInstance(
api=self, api=self,
api_prompt=await self._async_get_api_prompt(tool_context, exposed_entities), api_prompt=self._async_get_api_prompt(tool_context, exposed_entities),
tool_context=tool_context, tool_context=tool_context,
tools=self._async_get_tools(tool_context, exposed_entities), tools=self._async_get_tools(tool_context, exposed_entities),
) )
async def _async_get_api_prompt( @callback
def _async_get_api_prompt(
self, tool_context: ToolContext, exposed_entities: dict | None self, tool_context: ToolContext, exposed_entities: dict | None
) -> str: ) -> str:
"""Return the prompt for the API.""" """Return the prompt for the API."""
@ -269,15 +270,10 @@ class AssistAPI(API):
prompt.append(f"You are in area {area.name} {extra}") prompt.append(f"You are in area {area.name} {extra}")
else: else:
prompt.append( prompt.append(
"Reject all generic commands like 'turn on the lights' because we " "When a user asks to turn on all devices of a specific type, "
"don't know in what area this conversation is happening." "ask user to specify an area."
) )
if tool_context.context and tool_context.context.user_id:
user = await self.hass.auth.async_get_user(tool_context.context.user_id)
if user:
prompt.append(f"The user name is {user.name}.")
if not tool_context.device_id or not async_device_supports_timers( if not tool_context.device_id or not async_device_supports_timers(
self.hass, tool_context.device_id self.hass, tool_context.device_id
): ):

View File

@ -1,6 +1,6 @@
"""Tests for the llm helpers.""" """Tests for the llm helpers."""
from unittest.mock import Mock, patch from unittest.mock import patch
import pytest import pytest
import voluptuous as vol import voluptuous as vol
@ -430,8 +430,8 @@ async def test_assist_api_prompt(
no_timer_prompt = "This device does not support timers." no_timer_prompt = "This device does not support timers."
area_prompt = ( area_prompt = (
"Reject all generic commands like 'turn on the lights' because we don't know in what area " "When a user asks to turn on all devices of a specific type, "
"this conversation is happening." "ask user to specify an area."
) )
api = await llm.async_get_api(hass, "assist", tool_context) api = await llm.async_get_api(hass, "assist", tool_context)
assert api.api_prompt == ( assert api.api_prompt == (
@ -478,19 +478,5 @@ async def test_assist_api_prompt(
assert api.api_prompt == ( assert api.api_prompt == (
f"""{first_part_prompt} f"""{first_part_prompt}
{area_prompt} {area_prompt}
{exposed_entities_prompt}"""
)
# Add user
context.user_id = "12345"
mock_user = Mock()
mock_user.id = "12345"
mock_user.name = "Test User"
with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user):
api = await llm.async_get_api(hass, "assist", tool_context)
assert api.api_prompt == (
f"""{first_part_prompt}
{area_prompt}
The user name is Test User.
{exposed_entities_prompt}""" {exposed_entities_prompt}"""
) )