Allow llm API to render dynamic template prompt (#118055)

* Allow llm API to render dynamic template prompt

* Make rendering the API prompt async so it can support RAG

* Fix test
Paulus Schoutsen 2024-05-24 16:04:48 -04:00 committed by GitHub
parent 3b2cdb63f1
commit 7554ca9460
11 changed files with 137 additions and 359 deletions

View File

@@ -36,7 +36,6 @@ from .const import (
CONF_PROMPT,
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_TONE_PROMPT,
CONF_TOP_K,
CONF_TOP_P,
DEFAULT_PROMPT,
@@ -59,7 +58,7 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
RECOMMENDED_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
CONF_TONE_PROMPT: "",
CONF_PROMPT: "",
}
@@ -142,16 +141,11 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
# Re-render the options again, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
# If we switch to not recommended, generate used prompt.
if user_input[CONF_RECOMMENDED]:
options = RECOMMENDED_OPTIONS
else:
options = {
CONF_RECOMMENDED: False,
CONF_PROMPT: DEFAULT_PROMPT
+ "\n"
+ user_input.get(CONF_TONE_PROMPT, ""),
}
options = {
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
CONF_PROMPT: user_input[CONF_PROMPT],
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
}
schema = await google_generative_ai_config_option_schema(self.hass, options)
return self.async_show_form(
@@ -179,22 +173,24 @@ async def google_generative_ai_config_option_schema(
for api in llm.async_get_apis(hass)
)
schema = {
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT)},
default=DEFAULT_PROMPT,
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
}
if options.get(CONF_RECOMMENDED):
return {
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
vol.Optional(
CONF_TONE_PROMPT,
description={"suggested_value": options.get(CONF_TONE_PROMPT)},
default="",
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
}
return schema
api_models = await hass.async_add_executor_job(partial(genai.list_models))
@@ -211,45 +207,35 @@ async def google_generative_ai_config_option_schema(
)
]
return {
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
vol.Optional(
CONF_CHAT_MODEL,
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
default=RECOMMENDED_CHAT_MODEL,
): SelectSelector(
SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models)
),
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT)},
default=DEFAULT_PROMPT,
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=RECOMMENDED_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=RECOMMENDED_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TOP_K,
description={"suggested_value": options.get(CONF_TOP_K)},
default=RECOMMENDED_TOP_K,
): int,
vol.Optional(
CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=RECOMMENDED_MAX_TOKENS,
): int,
}
schema.update(
{
vol.Optional(
CONF_CHAT_MODEL,
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
default=RECOMMENDED_CHAT_MODEL,
): SelectSelector(
SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models)
),
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=RECOMMENDED_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=RECOMMENDED_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TOP_K,
description={"suggested_value": options.get(CONF_TOP_K)},
default=RECOMMENDED_TOP_K,
): int,
vol.Optional(
CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=RECOMMENDED_MAX_TOKENS,
): int,
}
)
return schema
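
The refactor above collapses the two mode-specific schemas into one shared base (prompt, LLM API selector, recommended toggle) that advanced mode extends with model-tuning fields. A minimal sketch of the pattern, assuming plain voluptuous validators in place of Home Assistant's selector classes and illustrative defaults:

import voluptuous as vol

def build_options_schema(options: dict, recommended: bool) -> vol.Schema:
    # Shared base: fields shown in both recommended and advanced mode.
    schema: dict = {
        vol.Optional(
            "prompt",
            description={"suggested_value": options.get("prompt")},
            default="",
        ): str,
        vol.Optional("llm_hass_api", default="none"): str,
        vol.Required("recommended", default=recommended): bool,
    }
    if recommended:
        return vol.Schema(schema)
    # Advanced mode extends the same base instead of redefining it.
    schema.update(
        {
            vol.Optional("temperature", default=1.0): vol.All(
                vol.Coerce(float), vol.Range(min=0, max=1)
            ),
            vol.Optional("max_tokens", default=150): int,
        }
    )
    return vol.Schema(schema)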

View File

@@ -5,24 +5,7 @@ import logging
DOMAIN = "google_generative_ai_conversation"
LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"
CONF_TONE_PROMPT = "tone_prompt"
DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
{%- for area in areas() %}
{%- set area_info = namespace(printed=false) %}
{%- for device in area_devices(area) -%}
{%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") and device_attr(device, "name") %}
{%- if not area_info.printed %}
{{ area_name(area) }}:
{%- set area_info.printed = true %}
{%- endif %}
- {{ device_attr(device, "name") }}{% if device_attr(device, "model") and (device_attr(device, "model") | string) not in (device_attr(device, "name") | string) %} ({{ device_attr(device, "model") }}){% endif %}
{%- endif %}
{%- endfor %}
{%- endfor %}
"""
DEFAULT_PROMPT = "Answer in plain text. Keep it simple and to the point."
CONF_RECOMMENDED = "recommended"
CONF_CHAT_MODEL = "chat_model"

View File

@@ -25,7 +25,6 @@ from .const import (
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TONE_PROMPT,
CONF_TOP_K,
CONF_TOP_P,
DEFAULT_PROMPT,
@@ -179,12 +178,32 @@ class GoogleGenerativeAIConversationEntity(
conversation_id = ulid.ulid_now()
messages = [{}, {}]
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
if tone_prompt := self.entry.options.get(CONF_TONE_PROMPT):
raw_prompt += "\n" + tone_prompt
try:
prompt = self._async_generate_prompt(raw_prompt, llm_api)
prompt = template.Template(
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)
if llm_api:
empty_tool_input = llm.ToolInput(
tool_name="",
tool_args={},
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
prompt = (
await llm_api.async_get_api_prompt(empty_tool_input) + "\n" + prompt
)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
intent_response.async_set_error(
@@ -271,18 +290,3 @@ class GoogleGenerativeAIConversationEntity(
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
def _async_generate_prompt(self, raw_prompt: str, llm_api: llm.API | None) -> str:
"""Generate a prompt for the user."""
raw_prompt += "\n"
if llm_api:
raw_prompt += llm_api.prompt_template
else:
raw_prompt += llm.PROMPT_NO_API_CONFIGURED
return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)
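
The behavioral change above is the assembly order and the await: the user's template renders first, then the API prompt, which is now a coroutine, is prepended when an LLM API is selected. A minimal sketch of just that step, using a stand-in protocol for the API object; the real code passes a fully populated llm.ToolInput:

from typing import Protocol

class _APIWithPrompt(Protocol):
    async def async_get_api_prompt(self, tool_input: object) -> str: ...

async def assemble_prompt(
    llm_api: _APIWithPrompt | None, tool_input: object, rendered_prompt: str
) -> str:
    """Prefix the rendered user template with the API prompt."""
    if llm_api is None:
        return rendered_prompt
    # The await is the point of the change: the API prompt may now be
    # produced asynchronously, e.g. by a retrieval (RAG) lookup.
    return await llm_api.async_get_api_prompt(tool_input) + "\n" + rendered_prompt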

View File

@@ -18,9 +18,8 @@
"step": {
"init": {
"data": {
"recommended": "Recommended settings",
"prompt": "Prompt",
"tone_prompt": "Tone",
"recommended": "Recommended model settings",
"prompt": "Instructions",
"chat_model": "[%key:common::generic::model%]",
"temperature": "Temperature",
"top_p": "Top P",
@@ -29,8 +28,7 @@
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
},
"data_description": {
"prompt": "Extra data to provide to the LLM. This can be a template.",
"tone_prompt": "Instructions for the LLM on the style of the generated text. This can be a template."
"prompt": "Instruct how the LLM should respond. This can be a template."
}
}
}

View File

@@ -5,23 +5,7 @@ import logging
DOMAIN = "openai_conversation"
LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"
DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
{%- for area in areas() %}
{%- set area_info = namespace(printed=false) %}
{%- for device in area_devices(area) -%}
{%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") and device_attr(device, "name") %}
{%- if not area_info.printed %}
{{ area_name(area) }}:
{%- set area_info.printed = true %}
{%- endif %}
- {{ device_attr(device, "name") }}{% if device_attr(device, "model") and (device_attr(device, "model") | string) not in (device_attr(device, "name") | string) %} ({{ device_attr(device, "model") }}){% endif %}
{%- endif %}
{%- endfor %}
{%- endfor %}
"""
DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
CONF_CHAT_MODEL = "chat_model"
DEFAULT_CHAT_MODEL = "gpt-4o"
CONF_MAX_TOKENS = "max_tokens"

View File

@@ -110,7 +110,6 @@ class OpenAIConversationEntity(
)
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
@@ -122,10 +121,33 @@
else:
conversation_id = ulid.ulid_now()
try:
prompt = self._async_generate_prompt(
raw_prompt,
llm_api,
prompt = template.Template(
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)
if llm_api:
empty_tool_input = llm.ToolInput(
tool_name="",
tool_args={},
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
prompt = (
await llm_api.async_get_api_prompt(empty_tool_input)
+ "\n"
+ prompt
)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
@@ -136,6 +158,7 @@ class OpenAIConversationEntity(
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
messages = [{"role": "system", "content": prompt}]
messages.append({"role": "user", "content": user_input.text})
@@ -213,22 +236,3 @@ class OpenAIConversationEntity(
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
def _async_generate_prompt(
self,
raw_prompt: str,
llm_api: llm.API | None,
) -> str:
"""Generate a prompt for the user."""
raw_prompt += "\n"
if llm_api:
raw_prompt += llm_api.prompt_template
else:
raw_prompt += llm.PROMPT_NO_API_CONFIGURED
return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)

View File

@@ -102,7 +102,11 @@ class API(ABC):
hass: HomeAssistant
id: str
name: str
prompt_template: str
@abstractmethod
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
"""Return the prompt for the API."""
raise NotImplementedError
@abstractmethod
@callback
@@ -183,9 +187,12 @@ class AssistAPI(API):
hass=hass,
id=LLM_API_ASSIST,
name="Assist",
prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
)
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
"""Return the prompt for the API."""
return "Call the intent tools to control Home Assistant. Just pass the name to the intent."
@callback
def async_get_tools(self) -> list[Tool]:
"""Return a list of LLM tools."""

View File

@@ -23,22 +23,7 @@
dict({
'history': list([
dict({
'parts': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'parts': 'Answer in plain text. Keep it simple and to the point.',
'role': 'user',
}),
dict({
@@ -82,22 +67,7 @@
dict({
'history': list([
dict({
'parts': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'parts': 'Answer in plain text. Keep it simple and to the point.',
'role': 'user',
}),
dict({
@@ -142,20 +112,8 @@
'history': list([
dict({
'parts': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Call the intent tools to control the system. Just pass the name to the intent.
Call the intent tools to control Home Assistant. Just pass the name to the intent.
Answer in plain text. Keep it simple and to the point.
''',
'role': 'user',
}),
@@ -201,20 +159,8 @@
'history': list([
dict({
'parts': '''
This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
Test Area:
- Test Device (Test Model)
Test Area 2:
- Test Device 2
- Test Device 3 (Test Model 3A)
- Test Device 4
- 1 (3)
Call the intent tools to control the system. Just pass the name to the intent.
Call the intent tools to control Home Assistant. Just pass the name to the intent.
Answer in plain text. Keep it simple and to the point.
''',
'role': 'user',
}),

View File

@@ -13,7 +13,6 @@ from homeassistant.components.google_generative_ai_conversation.const import (
CONF_PROMPT,
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_TONE_PROMPT,
CONF_TOP_K,
CONF_TOP_P,
DOMAIN,
@@ -90,7 +89,7 @@ async def test_form(hass: HomeAssistant) -> None:
assert result2["options"] == {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
CONF_TONE_PROMPT: "",
CONF_PROMPT: "",
}
assert len(mock_setup_entry.mock_calls) == 1
@@ -102,7 +101,7 @@ async def test_form(hass: HomeAssistant) -> None:
{
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "none",
CONF_TONE_PROMPT: "bla",
CONF_PROMPT: "bla",
},
{
CONF_RECOMMENDED: False,
@@ -132,12 +131,12 @@ async def test_form(hass: HomeAssistant) -> None:
{
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "assist",
CONF_TONE_PROMPT: "",
CONF_PROMPT: "",
},
{
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "assist",
CONF_TONE_PROMPT: "",
CONF_PROMPT: "",
},
),
],

View File

@@ -11,7 +11,6 @@ from openai.types.chat.chat_completion_message_tool_call import (
Function,
)
from openai.types.completion_usage import CompletionUsage
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
@@ -19,148 +18,12 @@ from homeassistant.components import conversation
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
intent,
llm,
)
from homeassistant.helpers import intent, llm
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
@pytest.mark.parametrize("agent_id", [None, "conversation.openai"])
@pytest.mark.parametrize(
"config_entry_options", [{}, {CONF_LLM_HASS_API: llm.LLM_API_ASSIST}]
)
async def test_default_prompt(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
agent_id: str,
config_entry_options: dict,
) -> None:
"""Test that the default prompt works."""
entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
for i in range(3):
area_registry.async_create(f"{i}Empty Area")
if agent_id is None:
agent_id = mock_config_entry.entry_id
hass.config_entries.async_update_entry(
mock_config_entry,
options={
**mock_config_entry.options,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
},
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
name="Test Device",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
)
for i in range(3):
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", f"{i}abcd")},
name="Test Service",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
entry_type=dr.DeviceEntryType.SERVICE,
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "5678")},
name="Test Device 2",
manufacturer="Test Manufacturer 2",
model="Device 2",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "qwer")},
name="Test Device 4",
suggested_area="Test Area 2",
)
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-disabled")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_update_device(
device.id, disabled_by=dr.DeviceEntryDisabler.USER
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-no-name")},
manufacturer="Test Manufacturer NoName",
model="Test Model NoName",
suggested_area="Test Area 2",
)
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-integer-values")},
name=1,
manufacturer=2,
model=3,
suggested_area="Test Area 2",
)
with patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
return_value=ChatCompletion(
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content="Hello, how can I help you?",
role="assistant",
function_call=None,
tool_calls=None,
),
)
],
created=1700000000,
model="gpt-3.5-turbo-0613",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
),
) as mock_create:
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=agent_id
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert mock_create.mock_calls[0][2]["messages"] == snapshot
async def test_error_handling(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:

View File

@@ -20,11 +20,15 @@ async def test_register_api(hass: HomeAssistant) -> None:
"""Test registering an llm api."""
class MyAPI(llm.API):
async def async_get_api_prompt(self, tool_input: llm.ToolInput) -> str:
"""Return a prompt for the tool."""
return ""
def async_get_tools(self) -> list[llm.Tool]:
"""Return a list of tools."""
return []
api = MyAPI(hass=hass, id="test", name="Test", prompt_template="")
api = MyAPI(hass=hass, id="test", name="Test")
llm.async_register_api(hass, api)
assert llm.async_get_api(hass, "test") is api
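
Since the contract is now async, callers and tests must await the prompt. A short usage sketch continuing from the registration above; the ToolInput field values are illustrative (field names taken from the diff), and Context comes from homeassistant.core:

tool_input = llm.ToolInput(
    tool_name="",
    tool_args={},
    platform="test",
    context=Context(),
    user_prompt="hello",
    language="en",
    assistant="conversation",
    device_id=None,
)
# MyAPI above returns an empty prompt, so this awaits to "".
assert await llm.async_get_api(hass, "test").async_get_api_prompt(tool_input) == ""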