mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
Allow llm API to render dynamic template prompt (#118055)
* Allow llm API to render dynamic template prompt * Make rendering api prompt async so it can become a RAG * Fix test
This commit is contained in:
parent
3b2cdb63f1
commit
7554ca9460
@ -36,7 +36,6 @@ from .const import (
|
||||
CONF_PROMPT,
|
||||
CONF_RECOMMENDED,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TONE_PROMPT,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
DEFAULT_PROMPT,
|
||||
@ -59,7 +58,7 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||
RECOMMENDED_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||
CONF_TONE_PROMPT: "",
|
||||
CONF_PROMPT: "",
|
||||
}
|
||||
|
||||
|
||||
@ -142,16 +141,11 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
|
||||
# Re-render the options again, now with the recommended options shown/hidden
|
||||
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
|
||||
|
||||
# If we switch to not recommended, generate used prompt.
|
||||
if user_input[CONF_RECOMMENDED]:
|
||||
options = RECOMMENDED_OPTIONS
|
||||
else:
|
||||
options = {
|
||||
CONF_RECOMMENDED: False,
|
||||
CONF_PROMPT: DEFAULT_PROMPT
|
||||
+ "\n"
|
||||
+ user_input.get(CONF_TONE_PROMPT, ""),
|
||||
}
|
||||
options = {
|
||||
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
|
||||
CONF_PROMPT: user_input[CONF_PROMPT],
|
||||
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
|
||||
}
|
||||
|
||||
schema = await google_generative_ai_config_option_schema(self.hass, options)
|
||||
return self.async_show_form(
|
||||
@ -179,22 +173,24 @@ async def google_generative_ai_config_option_schema(
|
||||
for api in llm.async_get_apis(hass)
|
||||
)
|
||||
|
||||
schema = {
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={"suggested_value": options.get(CONF_PROMPT)},
|
||||
default=DEFAULT_PROMPT,
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||
default="none",
|
||||
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
|
||||
vol.Required(
|
||||
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
|
||||
): bool,
|
||||
}
|
||||
|
||||
if options.get(CONF_RECOMMENDED):
|
||||
return {
|
||||
vol.Required(
|
||||
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
|
||||
): bool,
|
||||
vol.Optional(
|
||||
CONF_TONE_PROMPT,
|
||||
description={"suggested_value": options.get(CONF_TONE_PROMPT)},
|
||||
default="",
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||
default="none",
|
||||
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
|
||||
}
|
||||
return schema
|
||||
|
||||
api_models = await hass.async_add_executor_job(partial(genai.list_models))
|
||||
|
||||
@ -211,45 +207,35 @@ async def google_generative_ai_config_option_schema(
|
||||
)
|
||||
]
|
||||
|
||||
return {
|
||||
vol.Required(
|
||||
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
|
||||
): bool,
|
||||
vol.Optional(
|
||||
CONF_CHAT_MODEL,
|
||||
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
|
||||
default=RECOMMENDED_CHAT_MODEL,
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models)
|
||||
),
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={"suggested_value": options.get(CONF_PROMPT)},
|
||||
default=DEFAULT_PROMPT,
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||
default="none",
|
||||
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
|
||||
vol.Optional(
|
||||
CONF_TEMPERATURE,
|
||||
description={"suggested_value": options.get(CONF_TEMPERATURE)},
|
||||
default=RECOMMENDED_TEMPERATURE,
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
||||
vol.Optional(
|
||||
CONF_TOP_P,
|
||||
description={"suggested_value": options.get(CONF_TOP_P)},
|
||||
default=RECOMMENDED_TOP_P,
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
||||
vol.Optional(
|
||||
CONF_TOP_K,
|
||||
description={"suggested_value": options.get(CONF_TOP_K)},
|
||||
default=RECOMMENDED_TOP_K,
|
||||
): int,
|
||||
vol.Optional(
|
||||
CONF_MAX_TOKENS,
|
||||
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
|
||||
default=RECOMMENDED_MAX_TOKENS,
|
||||
): int,
|
||||
}
|
||||
schema.update(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_CHAT_MODEL,
|
||||
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
|
||||
default=RECOMMENDED_CHAT_MODEL,
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models)
|
||||
),
|
||||
vol.Optional(
|
||||
CONF_TEMPERATURE,
|
||||
description={"suggested_value": options.get(CONF_TEMPERATURE)},
|
||||
default=RECOMMENDED_TEMPERATURE,
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
||||
vol.Optional(
|
||||
CONF_TOP_P,
|
||||
description={"suggested_value": options.get(CONF_TOP_P)},
|
||||
default=RECOMMENDED_TOP_P,
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
||||
vol.Optional(
|
||||
CONF_TOP_K,
|
||||
description={"suggested_value": options.get(CONF_TOP_K)},
|
||||
default=RECOMMENDED_TOP_K,
|
||||
): int,
|
||||
vol.Optional(
|
||||
CONF_MAX_TOKENS,
|
||||
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
|
||||
default=RECOMMENDED_MAX_TOKENS,
|
||||
): int,
|
||||
}
|
||||
)
|
||||
return schema
|
||||
|
@ -5,24 +5,7 @@ import logging
|
||||
DOMAIN = "google_generative_ai_conversation"
|
||||
LOGGER = logging.getLogger(__package__)
|
||||
CONF_PROMPT = "prompt"
|
||||
CONF_TONE_PROMPT = "tone_prompt"
|
||||
DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
|
||||
|
||||
An overview of the areas and the devices in this smart home:
|
||||
{%- for area in areas() %}
|
||||
{%- set area_info = namespace(printed=false) %}
|
||||
{%- for device in area_devices(area) -%}
|
||||
{%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") and device_attr(device, "name") %}
|
||||
{%- if not area_info.printed %}
|
||||
|
||||
{{ area_name(area) }}:
|
||||
{%- set area_info.printed = true %}
|
||||
{%- endif %}
|
||||
- {{ device_attr(device, "name") }}{% if device_attr(device, "model") and (device_attr(device, "model") | string) not in (device_attr(device, "name") | string) %} ({{ device_attr(device, "model") }}){% endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
"""
|
||||
DEFAULT_PROMPT = "Answer in plain text. Keep it simple and to the point."
|
||||
|
||||
CONF_RECOMMENDED = "recommended"
|
||||
CONF_CHAT_MODEL = "chat_model"
|
||||
|
@ -25,7 +25,6 @@ from .const import (
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_PROMPT,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TONE_PROMPT,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
DEFAULT_PROMPT,
|
||||
@ -179,12 +178,32 @@ class GoogleGenerativeAIConversationEntity(
|
||||
conversation_id = ulid.ulid_now()
|
||||
messages = [{}, {}]
|
||||
|
||||
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||
if tone_prompt := self.entry.options.get(CONF_TONE_PROMPT):
|
||||
raw_prompt += "\n" + tone_prompt
|
||||
|
||||
try:
|
||||
prompt = self._async_generate_prompt(raw_prompt, llm_api)
|
||||
prompt = template.Template(
|
||||
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
|
||||
).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
|
||||
if llm_api:
|
||||
empty_tool_input = llm.ToolInput(
|
||||
tool_name="",
|
||||
tool_args={},
|
||||
platform=DOMAIN,
|
||||
context=user_input.context,
|
||||
user_prompt=user_input.text,
|
||||
language=user_input.language,
|
||||
assistant=conversation.DOMAIN,
|
||||
device_id=user_input.device_id,
|
||||
)
|
||||
|
||||
prompt = (
|
||||
await llm_api.async_get_api_prompt(empty_tool_input) + "\n" + prompt
|
||||
)
|
||||
|
||||
except TemplateError as err:
|
||||
LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response.async_set_error(
|
||||
@ -271,18 +290,3 @@ class GoogleGenerativeAIConversationEntity(
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
def _async_generate_prompt(self, raw_prompt: str, llm_api: llm.API | None) -> str:
|
||||
"""Generate a prompt for the user."""
|
||||
raw_prompt += "\n"
|
||||
if llm_api:
|
||||
raw_prompt += llm_api.prompt_template
|
||||
else:
|
||||
raw_prompt += llm.PROMPT_NO_API_CONFIGURED
|
||||
|
||||
return template.Template(raw_prompt, self.hass).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
|
@ -18,9 +18,8 @@
|
||||
"step": {
|
||||
"init": {
|
||||
"data": {
|
||||
"recommended": "Recommended settings",
|
||||
"prompt": "Prompt",
|
||||
"tone_prompt": "Tone",
|
||||
"recommended": "Recommended model settings",
|
||||
"prompt": "Instructions",
|
||||
"chat_model": "[%key:common::generic::model%]",
|
||||
"temperature": "Temperature",
|
||||
"top_p": "Top P",
|
||||
@ -29,8 +28,7 @@
|
||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
|
||||
},
|
||||
"data_description": {
|
||||
"prompt": "Extra data to provide to the LLM. This can be a template.",
|
||||
"tone_prompt": "Instructions for the LLM on the style of the generated text. This can be a template."
|
||||
"prompt": "Instruct how the LLM should respond. This can be a template."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,23 +5,7 @@ import logging
|
||||
DOMAIN = "openai_conversation"
|
||||
LOGGER = logging.getLogger(__package__)
|
||||
CONF_PROMPT = "prompt"
|
||||
DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
|
||||
|
||||
An overview of the areas and the devices in this smart home:
|
||||
{%- for area in areas() %}
|
||||
{%- set area_info = namespace(printed=false) %}
|
||||
{%- for device in area_devices(area) -%}
|
||||
{%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") and device_attr(device, "name") %}
|
||||
{%- if not area_info.printed %}
|
||||
|
||||
{{ area_name(area) }}:
|
||||
{%- set area_info.printed = true %}
|
||||
{%- endif %}
|
||||
- {{ device_attr(device, "name") }}{% if device_attr(device, "model") and (device_attr(device, "model") | string) not in (device_attr(device, "name") | string) %} ({{ device_attr(device, "model") }}){% endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
"""
|
||||
DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
|
||||
CONF_CHAT_MODEL = "chat_model"
|
||||
DEFAULT_CHAT_MODEL = "gpt-4o"
|
||||
CONF_MAX_TOKENS = "max_tokens"
|
||||
|
@ -110,7 +110,6 @@ class OpenAIConversationEntity(
|
||||
)
|
||||
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
|
||||
|
||||
raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
|
||||
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
||||
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
|
||||
@ -122,10 +121,33 @@ class OpenAIConversationEntity(
|
||||
else:
|
||||
conversation_id = ulid.ulid_now()
|
||||
try:
|
||||
prompt = self._async_generate_prompt(
|
||||
raw_prompt,
|
||||
llm_api,
|
||||
prompt = template.Template(
|
||||
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
|
||||
).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
|
||||
if llm_api:
|
||||
empty_tool_input = llm.ToolInput(
|
||||
tool_name="",
|
||||
tool_args={},
|
||||
platform=DOMAIN,
|
||||
context=user_input.context,
|
||||
user_prompt=user_input.text,
|
||||
language=user_input.language,
|
||||
assistant=conversation.DOMAIN,
|
||||
device_id=user_input.device_id,
|
||||
)
|
||||
|
||||
prompt = (
|
||||
await llm_api.async_get_api_prompt(empty_tool_input)
|
||||
+ "\n"
|
||||
+ prompt
|
||||
)
|
||||
|
||||
except TemplateError as err:
|
||||
LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
@ -136,6 +158,7 @@ class OpenAIConversationEntity(
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": prompt}]
|
||||
|
||||
messages.append({"role": "user", "content": user_input.text})
|
||||
@ -213,22 +236,3 @@ class OpenAIConversationEntity(
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
def _async_generate_prompt(
|
||||
self,
|
||||
raw_prompt: str,
|
||||
llm_api: llm.API | None,
|
||||
) -> str:
|
||||
"""Generate a prompt for the user."""
|
||||
raw_prompt += "\n"
|
||||
if llm_api:
|
||||
raw_prompt += llm_api.prompt_template
|
||||
else:
|
||||
raw_prompt += llm.PROMPT_NO_API_CONFIGURED
|
||||
|
||||
return template.Template(raw_prompt, self.hass).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
|
@ -102,7 +102,11 @@ class API(ABC):
|
||||
hass: HomeAssistant
|
||||
id: str
|
||||
name: str
|
||||
prompt_template: str
|
||||
|
||||
@abstractmethod
|
||||
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
|
||||
"""Return the prompt for the API."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@callback
|
||||
@ -183,9 +187,12 @@ class AssistAPI(API):
|
||||
hass=hass,
|
||||
id=LLM_API_ASSIST,
|
||||
name="Assist",
|
||||
prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
|
||||
)
|
||||
|
||||
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
|
||||
"""Return the prompt for the API."""
|
||||
return "Call the intent tools to control Home Assistant. Just pass the name to the intent."
|
||||
|
||||
@callback
|
||||
def async_get_tools(self) -> list[Tool]:
|
||||
"""Return a list of LLM tools."""
|
||||
|
@ -23,22 +23,7 @@
|
||||
dict({
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': '''
|
||||
This smart home is controlled by Home Assistant.
|
||||
|
||||
An overview of the areas and the devices in this smart home:
|
||||
|
||||
Test Area:
|
||||
- Test Device (Test Model)
|
||||
|
||||
Test Area 2:
|
||||
- Test Device 2
|
||||
- Test Device 3 (Test Model 3A)
|
||||
- Test Device 4
|
||||
- 1 (3)
|
||||
|
||||
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||
''',
|
||||
'parts': 'Answer in plain text. Keep it simple and to the point.',
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
@ -82,22 +67,7 @@
|
||||
dict({
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': '''
|
||||
This smart home is controlled by Home Assistant.
|
||||
|
||||
An overview of the areas and the devices in this smart home:
|
||||
|
||||
Test Area:
|
||||
- Test Device (Test Model)
|
||||
|
||||
Test Area 2:
|
||||
- Test Device 2
|
||||
- Test Device 3 (Test Model 3A)
|
||||
- Test Device 4
|
||||
- 1 (3)
|
||||
|
||||
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||
''',
|
||||
'parts': 'Answer in plain text. Keep it simple and to the point.',
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
@ -142,20 +112,8 @@
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': '''
|
||||
This smart home is controlled by Home Assistant.
|
||||
|
||||
An overview of the areas and the devices in this smart home:
|
||||
|
||||
Test Area:
|
||||
- Test Device (Test Model)
|
||||
|
||||
Test Area 2:
|
||||
- Test Device 2
|
||||
- Test Device 3 (Test Model 3A)
|
||||
- Test Device 4
|
||||
- 1 (3)
|
||||
|
||||
Call the intent tools to control the system. Just pass the name to the intent.
|
||||
Call the intent tools to control Home Assistant. Just pass the name to the intent.
|
||||
Answer in plain text. Keep it simple and to the point.
|
||||
''',
|
||||
'role': 'user',
|
||||
}),
|
||||
@ -201,20 +159,8 @@
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': '''
|
||||
This smart home is controlled by Home Assistant.
|
||||
|
||||
An overview of the areas and the devices in this smart home:
|
||||
|
||||
Test Area:
|
||||
- Test Device (Test Model)
|
||||
|
||||
Test Area 2:
|
||||
- Test Device 2
|
||||
- Test Device 3 (Test Model 3A)
|
||||
- Test Device 4
|
||||
- 1 (3)
|
||||
|
||||
Call the intent tools to control the system. Just pass the name to the intent.
|
||||
Call the intent tools to control Home Assistant. Just pass the name to the intent.
|
||||
Answer in plain text. Keep it simple and to the point.
|
||||
''',
|
||||
'role': 'user',
|
||||
}),
|
||||
|
@ -13,7 +13,6 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_PROMPT,
|
||||
CONF_RECOMMENDED,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TONE_PROMPT,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
DOMAIN,
|
||||
@ -90,7 +89,7 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
assert result2["options"] == {
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||
CONF_TONE_PROMPT: "",
|
||||
CONF_PROMPT: "",
|
||||
}
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
@ -102,7 +101,7 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
{
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: "none",
|
||||
CONF_TONE_PROMPT: "bla",
|
||||
CONF_PROMPT: "bla",
|
||||
},
|
||||
{
|
||||
CONF_RECOMMENDED: False,
|
||||
@ -132,12 +131,12 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
{
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: "assist",
|
||||
CONF_TONE_PROMPT: "",
|
||||
CONF_PROMPT: "",
|
||||
},
|
||||
{
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: "assist",
|
||||
CONF_TONE_PROMPT: "",
|
||||
CONF_PROMPT: "",
|
||||
},
|
||||
),
|
||||
],
|
||||
|
@ -11,7 +11,6 @@ from openai.types.chat.chat_completion_message_tool_call import (
|
||||
Function,
|
||||
)
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
import voluptuous as vol
|
||||
|
||||
@ -19,148 +18,12 @@ from homeassistant.components import conversation
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
device_registry as dr,
|
||||
intent,
|
||||
llm,
|
||||
)
|
||||
from homeassistant.helpers import intent, llm
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.mark.parametrize("agent_id", [None, "conversation.openai"])
|
||||
@pytest.mark.parametrize(
|
||||
"config_entry_options", [{}, {CONF_LLM_HASS_API: llm.LLM_API_ASSIST}]
|
||||
)
|
||||
async def test_default_prompt(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
area_registry: ar.AreaRegistry,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
snapshot: SnapshotAssertion,
|
||||
agent_id: str,
|
||||
config_entry_options: dict,
|
||||
) -> None:
|
||||
"""Test that the default prompt works."""
|
||||
entry = MockConfigEntry(title=None)
|
||||
entry.add_to_hass(hass)
|
||||
for i in range(3):
|
||||
area_registry.async_create(f"{i}Empty Area")
|
||||
|
||||
if agent_id is None:
|
||||
agent_id = mock_config_entry.entry_id
|
||||
|
||||
hass.config_entries.async_update_entry(
|
||||
mock_config_entry,
|
||||
options={
|
||||
**mock_config_entry.options,
|
||||
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||
},
|
||||
)
|
||||
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "1234")},
|
||||
name="Test Device",
|
||||
manufacturer="Test Manufacturer",
|
||||
model="Test Model",
|
||||
suggested_area="Test Area",
|
||||
)
|
||||
for i in range(3):
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", f"{i}abcd")},
|
||||
name="Test Service",
|
||||
manufacturer="Test Manufacturer",
|
||||
model="Test Model",
|
||||
suggested_area="Test Area",
|
||||
entry_type=dr.DeviceEntryType.SERVICE,
|
||||
)
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "5678")},
|
||||
name="Test Device 2",
|
||||
manufacturer="Test Manufacturer 2",
|
||||
model="Device 2",
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "9876")},
|
||||
name="Test Device 3",
|
||||
manufacturer="Test Manufacturer 3",
|
||||
model="Test Model 3A",
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "qwer")},
|
||||
name="Test Device 4",
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
device = device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "9876-disabled")},
|
||||
name="Test Device 3",
|
||||
manufacturer="Test Manufacturer 3",
|
||||
model="Test Model 3A",
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
device_registry.async_update_device(
|
||||
device.id, disabled_by=dr.DeviceEntryDisabler.USER
|
||||
)
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "9876-no-name")},
|
||||
manufacturer="Test Manufacturer NoName",
|
||||
model="Test Model NoName",
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "9876-integer-values")},
|
||||
name=1,
|
||||
manufacturer=2,
|
||||
model=3,
|
||||
suggested_area="Test Area 2",
|
||||
)
|
||||
with patch(
|
||||
"openai.resources.chat.completions.AsyncCompletions.create",
|
||||
new_callable=AsyncMock,
|
||||
return_value=ChatCompletion(
|
||||
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content="Hello, how can I help you?",
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
tool_calls=None,
|
||||
),
|
||||
)
|
||||
],
|
||||
created=1700000000,
|
||||
model="gpt-3.5-turbo-0613",
|
||||
object="chat.completion",
|
||||
system_fingerprint=None,
|
||||
usage=CompletionUsage(
|
||||
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
||||
),
|
||||
),
|
||||
) as mock_create:
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=agent_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert mock_create.mock_calls[0][2]["messages"] == snapshot
|
||||
|
||||
|
||||
async def test_error_handling(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
|
@ -20,11 +20,15 @@ async def test_register_api(hass: HomeAssistant) -> None:
|
||||
"""Test registering an llm api."""
|
||||
|
||||
class MyAPI(llm.API):
|
||||
async def async_get_api_prompt(self, tool_input: llm.ToolInput) -> str:
|
||||
"""Return a prompt for the tool."""
|
||||
return ""
|
||||
|
||||
def async_get_tools(self) -> list[llm.Tool]:
|
||||
"""Return a list of tools."""
|
||||
return []
|
||||
|
||||
api = MyAPI(hass=hass, id="test", name="Test", prompt_template="")
|
||||
api = MyAPI(hass=hass, id="test", name="Test")
|
||||
llm.async_register_api(hass, api)
|
||||
|
||||
assert llm.async_get_api(hass, "test") is api
|
||||
|
Loading…
x
Reference in New Issue
Block a user