mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 15:17:35 +00:00
Tweak Assist LLM API prompt (#118343)
This commit is contained in:
parent
d223e1f2ac
commit
c097a05ed4
@ -227,12 +227,13 @@ class AssistAPI(API):
|
|||||||
|
|
||||||
return APIInstance(
|
return APIInstance(
|
||||||
api=self,
|
api=self,
|
||||||
api_prompt=await self._async_get_api_prompt(tool_context, exposed_entities),
|
api_prompt=self._async_get_api_prompt(tool_context, exposed_entities),
|
||||||
tool_context=tool_context,
|
tool_context=tool_context,
|
||||||
tools=self._async_get_tools(tool_context, exposed_entities),
|
tools=self._async_get_tools(tool_context, exposed_entities),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _async_get_api_prompt(
|
@callback
|
||||||
|
def _async_get_api_prompt(
|
||||||
self, tool_context: ToolContext, exposed_entities: dict | None
|
self, tool_context: ToolContext, exposed_entities: dict | None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return the prompt for the API."""
|
"""Return the prompt for the API."""
|
||||||
@ -269,15 +270,10 @@ class AssistAPI(API):
|
|||||||
prompt.append(f"You are in area {area.name} {extra}")
|
prompt.append(f"You are in area {area.name} {extra}")
|
||||||
else:
|
else:
|
||||||
prompt.append(
|
prompt.append(
|
||||||
"Reject all generic commands like 'turn on the lights' because we "
|
"When a user asks to turn on all devices of a specific type, "
|
||||||
"don't know in what area this conversation is happening."
|
"ask user to specify an area."
|
||||||
)
|
)
|
||||||
|
|
||||||
if tool_context.context and tool_context.context.user_id:
|
|
||||||
user = await self.hass.auth.async_get_user(tool_context.context.user_id)
|
|
||||||
if user:
|
|
||||||
prompt.append(f"The user name is {user.name}.")
|
|
||||||
|
|
||||||
if not tool_context.device_id or not async_device_supports_timers(
|
if not tool_context.device_id or not async_device_supports_timers(
|
||||||
self.hass, tool_context.device_id
|
self.hass, tool_context.device_id
|
||||||
):
|
):
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Tests for the llm helpers."""
|
"""Tests for the llm helpers."""
|
||||||
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -430,8 +430,8 @@ async def test_assist_api_prompt(
|
|||||||
no_timer_prompt = "This device does not support timers."
|
no_timer_prompt = "This device does not support timers."
|
||||||
|
|
||||||
area_prompt = (
|
area_prompt = (
|
||||||
"Reject all generic commands like 'turn on the lights' because we don't know in what area "
|
"When a user asks to turn on all devices of a specific type, "
|
||||||
"this conversation is happening."
|
"ask user to specify an area."
|
||||||
)
|
)
|
||||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||||
assert api.api_prompt == (
|
assert api.api_prompt == (
|
||||||
@ -478,19 +478,5 @@ async def test_assist_api_prompt(
|
|||||||
assert api.api_prompt == (
|
assert api.api_prompt == (
|
||||||
f"""{first_part_prompt}
|
f"""{first_part_prompt}
|
||||||
{area_prompt}
|
{area_prompt}
|
||||||
{exposed_entities_prompt}"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add user
|
|
||||||
context.user_id = "12345"
|
|
||||||
mock_user = Mock()
|
|
||||||
mock_user.id = "12345"
|
|
||||||
mock_user.name = "Test User"
|
|
||||||
with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user):
|
|
||||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
|
||||||
assert api.api_prompt == (
|
|
||||||
f"""{first_part_prompt}
|
|
||||||
{area_prompt}
|
|
||||||
The user name is Test User.
|
|
||||||
{exposed_entities_prompt}"""
|
{exposed_entities_prompt}"""
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user