Allow llm API to render dynamic template prompt (#118055)

* Allow llm API to render dynamic template prompt

* Make rendering api prompt async so it can become a RAG

* Fix test
Paulus Schoutsen 2024-05-24 16:04:48 -04:00 committed by GitHub
parent 3b2cdb63f1
commit 7554ca9460
GPG Key ID: B5690EEEBB952194
11 changed files with 137 additions and 359 deletions
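
In short, this commit replaces the static `llm.API.prompt_template` string with an async `async_get_api_prompt()` hook and moves the smart-home-overview template out of the integrations' default prompts. The conversation agents now render the user-configured prompt template first, then let the selected LLM API prepend its own prompt. A condensed sketch of that composition, adapted from the two conversation entity diffs below (the helper name `_build_system_prompt` is illustrative; `self`, `user_input`, and `llm_api` are what the real entities already have in scope):

    # Sketch only: assumes the imports and constants of the real conversation
    # entity modules (template, llm, conversation, DOMAIN, CONF_PROMPT,
    # DEFAULT_PROMPT).
    async def _build_system_prompt(self, user_input, llm_api):
        # 1. Render the user-configurable prompt template.
        prompt = template.Template(
            self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
        ).async_render(
            {"ha_name": self.hass.config.location_name},
            parse_result=False,
        )
        # 2. If an LLM API is selected, prepend its prompt. The hook is
        #    async, so it may perform I/O (e.g. retrieval for RAG).
        if llm_api:
            tool_input = llm.ToolInput(
                tool_name="",  # no tool is being called yet
                tool_args={},
                platform=DOMAIN,
                context=user_input.context,
                user_prompt=user_input.text,
                language=user_input.language,
                assistant=conversation.DOMAIN,
                device_id=user_input.device_id,
            )
            prompt = await llm_api.async_get_api_prompt(tool_input) + "\n" + prompt
        return prompt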

View File

@@ -36,7 +36,6 @@ from .const import (
     CONF_PROMPT,
     CONF_RECOMMENDED,
     CONF_TEMPERATURE,
-    CONF_TONE_PROMPT,
     CONF_TOP_K,
     CONF_TOP_P,
     DEFAULT_PROMPT,
@@ -59,7 +58,7 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
 RECOMMENDED_OPTIONS = {
     CONF_RECOMMENDED: True,
     CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
-    CONF_TONE_PROMPT: "",
+    CONF_PROMPT: "",
 }
@@ -142,16 +141,11 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
             # Re-render the options again, now with the recommended options shown/hidden
             self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
 
-            # If we switch to not recommended, generate used prompt.
-            if user_input[CONF_RECOMMENDED]:
-                options = RECOMMENDED_OPTIONS
-            else:
-                options = {
-                    CONF_RECOMMENDED: False,
-                    CONF_PROMPT: DEFAULT_PROMPT
-                    + "\n"
-                    + user_input.get(CONF_TONE_PROMPT, ""),
-                }
+            options = {
+                CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
+                CONF_PROMPT: user_input[CONF_PROMPT],
+                CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
+            }
 
         schema = await google_generative_ai_config_option_schema(self.hass, options)
         return self.async_show_form(
@@ -179,22 +173,24 @@ async def google_generative_ai_config_option_schema(
         for api in llm.async_get_apis(hass)
     )
 
+    schema = {
+        vol.Optional(
+            CONF_PROMPT,
+            description={"suggested_value": options.get(CONF_PROMPT)},
+            default=DEFAULT_PROMPT,
+        ): TemplateSelector(),
+        vol.Optional(
+            CONF_LLM_HASS_API,
+            description={"suggested_value": options.get(CONF_LLM_HASS_API)},
+            default="none",
+        ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
+        vol.Required(
+            CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
+        ): bool,
+    }
+
     if options.get(CONF_RECOMMENDED):
-        return {
-            vol.Required(
-                CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
-            ): bool,
-            vol.Optional(
-                CONF_TONE_PROMPT,
-                description={"suggested_value": options.get(CONF_TONE_PROMPT)},
-                default="",
-            ): TemplateSelector(),
-            vol.Optional(
-                CONF_LLM_HASS_API,
-                description={"suggested_value": options.get(CONF_LLM_HASS_API)},
-                default="none",
-            ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
-        }
+        return schema
 
     api_models = await hass.async_add_executor_job(partial(genai.list_models))
@@ -211,45 +207,35 @@ async def google_generative_ai_config_option_schema(
         )
     ]
 
-    return {
-        vol.Required(
-            CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
-        ): bool,
-        vol.Optional(
-            CONF_CHAT_MODEL,
-            description={"suggested_value": options.get(CONF_CHAT_MODEL)},
-            default=RECOMMENDED_CHAT_MODEL,
-        ): SelectSelector(
-            SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models)
-        ),
-        vol.Optional(
-            CONF_PROMPT,
-            description={"suggested_value": options.get(CONF_PROMPT)},
-            default=DEFAULT_PROMPT,
-        ): TemplateSelector(),
-        vol.Optional(
-            CONF_LLM_HASS_API,
-            description={"suggested_value": options.get(CONF_LLM_HASS_API)},
-            default="none",
-        ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
-        vol.Optional(
-            CONF_TEMPERATURE,
-            description={"suggested_value": options.get(CONF_TEMPERATURE)},
-            default=RECOMMENDED_TEMPERATURE,
-        ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
-        vol.Optional(
-            CONF_TOP_P,
-            description={"suggested_value": options.get(CONF_TOP_P)},
-            default=RECOMMENDED_TOP_P,
-        ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
-        vol.Optional(
-            CONF_TOP_K,
-            description={"suggested_value": options.get(CONF_TOP_K)},
-            default=RECOMMENDED_TOP_K,
-        ): int,
-        vol.Optional(
-            CONF_MAX_TOKENS,
-            description={"suggested_value": options.get(CONF_MAX_TOKENS)},
-            default=RECOMMENDED_MAX_TOKENS,
-        ): int,
-    }
+    schema.update(
+        {
+            vol.Optional(
+                CONF_CHAT_MODEL,
+                description={"suggested_value": options.get(CONF_CHAT_MODEL)},
+                default=RECOMMENDED_CHAT_MODEL,
+            ): SelectSelector(
+                SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models)
+            ),
+            vol.Optional(
+                CONF_TEMPERATURE,
+                description={"suggested_value": options.get(CONF_TEMPERATURE)},
+                default=RECOMMENDED_TEMPERATURE,
+            ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
+            vol.Optional(
+                CONF_TOP_P,
+                description={"suggested_value": options.get(CONF_TOP_P)},
+                default=RECOMMENDED_TOP_P,
+            ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
+            vol.Optional(
+                CONF_TOP_K,
+                description={"suggested_value": options.get(CONF_TOP_K)},
+                default=RECOMMENDED_TOP_K,
+            ): int,
+            vol.Optional(
+                CONF_MAX_TOKENS,
+                description={"suggested_value": options.get(CONF_MAX_TOKENS)},
+                default=RECOMMENDED_MAX_TOKENS,
+            ): int,
+        }
+    )
+    return schema

View File

@@ -5,24 +5,7 @@ import logging
 DOMAIN = "google_generative_ai_conversation"
 LOGGER = logging.getLogger(__package__)
 CONF_PROMPT = "prompt"
-CONF_TONE_PROMPT = "tone_prompt"
-DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
-
-An overview of the areas and the devices in this smart home:
-{%- for area in areas() %}
-  {%- set area_info = namespace(printed=false) %}
-  {%- for device in area_devices(area) -%}
-    {%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") and device_attr(device, "name") %}
-      {%- if not area_info.printed %}
-
-{{ area_name(area) }}:
-        {%- set area_info.printed = true %}
-      {%- endif %}
-- {{ device_attr(device, "name") }}{% if device_attr(device, "model") and (device_attr(device, "model") | string) not in (device_attr(device, "name") | string) %} ({{ device_attr(device, "model") }}){% endif %}
-    {%- endif %}
-  {%- endfor %}
-{%- endfor %}
-"""
+DEFAULT_PROMPT = "Answer in plain text. Keep it simple and to the point."
 CONF_RECOMMENDED = "recommended"
 CONF_CHAT_MODEL = "chat_model"

View File

@@ -25,7 +25,6 @@ from .const import (
     CONF_MAX_TOKENS,
     CONF_PROMPT,
     CONF_TEMPERATURE,
-    CONF_TONE_PROMPT,
     CONF_TOP_K,
     CONF_TOP_P,
     DEFAULT_PROMPT,
@@ -179,12 +178,32 @@ class GoogleGenerativeAIConversationEntity(
             conversation_id = ulid.ulid_now()
             messages = [{}, {}]
 
-        raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
-        if tone_prompt := self.entry.options.get(CONF_TONE_PROMPT):
-            raw_prompt += "\n" + tone_prompt
-
         try:
-            prompt = self._async_generate_prompt(raw_prompt, llm_api)
+            prompt = template.Template(
+                self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
+            ).async_render(
+                {
+                    "ha_name": self.hass.config.location_name,
+                },
+                parse_result=False,
+            )
+
+            if llm_api:
+                empty_tool_input = llm.ToolInput(
+                    tool_name="",
+                    tool_args={},
+                    platform=DOMAIN,
+                    context=user_input.context,
+                    user_prompt=user_input.text,
+                    language=user_input.language,
+                    assistant=conversation.DOMAIN,
+                    device_id=user_input.device_id,
+                )
+
+                prompt = (
+                    await llm_api.async_get_api_prompt(empty_tool_input) + "\n" + prompt
+                )
+
         except TemplateError as err:
             LOGGER.error("Error rendering prompt: %s", err)
             intent_response.async_set_error(
@@ -271,18 +290,3 @@ class GoogleGenerativeAIConversationEntity(
         return conversation.ConversationResult(
             response=intent_response, conversation_id=conversation_id
         )
-
-    def _async_generate_prompt(self, raw_prompt: str, llm_api: llm.API | None) -> str:
-        """Generate a prompt for the user."""
-        raw_prompt += "\n"
-        if llm_api:
-            raw_prompt += llm_api.prompt_template
-        else:
-            raw_prompt += llm.PROMPT_NO_API_CONFIGURED
-
-        return template.Template(raw_prompt, self.hass).async_render(
-            {
-                "ha_name": self.hass.config.location_name,
-            },
-            parse_result=False,
-        )

View File

@@ -18,9 +18,8 @@
       "step": {
         "init": {
           "data": {
-            "recommended": "Recommended settings",
-            "prompt": "Prompt",
-            "tone_prompt": "Tone",
+            "recommended": "Recommended model settings",
+            "prompt": "Instructions",
             "chat_model": "[%key:common::generic::model%]",
             "temperature": "Temperature",
             "top_p": "Top P",
@@ -29,8 +28,7 @@
             "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
           },
           "data_description": {
-            "prompt": "Extra data to provide to the LLM. This can be a template.",
-            "tone_prompt": "Instructions for the LLM on the style of the generated text. This can be a template."
+            "prompt": "Instruct how the LLM should respond. This can be a template."
           }
         }
       }

View File

@@ -5,23 +5,7 @@ import logging
 DOMAIN = "openai_conversation"
 LOGGER = logging.getLogger(__package__)
 CONF_PROMPT = "prompt"
-DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
-
-An overview of the areas and the devices in this smart home:
-{%- for area in areas() %}
-  {%- set area_info = namespace(printed=false) %}
-  {%- for device in area_devices(area) -%}
-    {%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") and device_attr(device, "name") %}
-      {%- if not area_info.printed %}
-
-{{ area_name(area) }}:
-        {%- set area_info.printed = true %}
-      {%- endif %}
-- {{ device_attr(device, "name") }}{% if device_attr(device, "model") and (device_attr(device, "model") | string) not in (device_attr(device, "name") | string) %} ({{ device_attr(device, "model") }}){% endif %}
-    {%- endif %}
-  {%- endfor %}
-{%- endfor %}
-"""
+DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
 CONF_CHAT_MODEL = "chat_model"
 DEFAULT_CHAT_MODEL = "gpt-4o"
 CONF_MAX_TOKENS = "max_tokens"

View File

@@ -110,7 +110,6 @@ class OpenAIConversationEntity(
             )
             tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
 
-        raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
         model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
         max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
         top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
@@ -122,10 +121,33 @@ class OpenAIConversationEntity(
         else:
             conversation_id = ulid.ulid_now()
 
        try:
-            prompt = self._async_generate_prompt(
-                raw_prompt,
-                llm_api,
-            )
+            prompt = template.Template(
+                self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
+            ).async_render(
+                {
+                    "ha_name": self.hass.config.location_name,
+                },
+                parse_result=False,
+            )
+
+            if llm_api:
+                empty_tool_input = llm.ToolInput(
+                    tool_name="",
+                    tool_args={},
+                    platform=DOMAIN,
+                    context=user_input.context,
+                    user_prompt=user_input.text,
+                    language=user_input.language,
+                    assistant=conversation.DOMAIN,
+                    device_id=user_input.device_id,
+                )
+
+                prompt = (
+                    await llm_api.async_get_api_prompt(empty_tool_input)
+                    + "\n"
+                    + prompt
+                )
+
         except TemplateError as err:
             LOGGER.error("Error rendering prompt: %s", err)
             intent_response = intent.IntentResponse(language=user_input.language)
@@ -136,6 +158,7 @@ class OpenAIConversationEntity(
             return conversation.ConversationResult(
                 response=intent_response, conversation_id=conversation_id
             )
+
         messages = [{"role": "system", "content": prompt}]
         messages.append({"role": "user", "content": user_input.text})
@@ -213,22 +236,3 @@ class OpenAIConversationEntity(
         return conversation.ConversationResult(
             response=intent_response, conversation_id=conversation_id
         )
-
-    def _async_generate_prompt(
-        self,
-        raw_prompt: str,
-        llm_api: llm.API | None,
-    ) -> str:
-        """Generate a prompt for the user."""
-        raw_prompt += "\n"
-        if llm_api:
-            raw_prompt += llm_api.prompt_template
-        else:
-            raw_prompt += llm.PROMPT_NO_API_CONFIGURED
-
-        return template.Template(raw_prompt, self.hass).async_render(
-            {
-                "ha_name": self.hass.config.location_name,
-            },
-            parse_result=False,
-        )

View File

@@ -102,7 +102,11 @@ class API(ABC):
     hass: HomeAssistant
     id: str
     name: str
-    prompt_template: str
+
+    @abstractmethod
+    async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
+        """Return the prompt for the API."""
+        raise NotImplementedError
 
     @abstractmethod
     @callback
@@ -183,9 +187,12 @@ class AssistAPI(API):
             hass=hass,
             id=LLM_API_ASSIST,
             name="Assist",
-            prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
         )
 
+    async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
+        """Return the prompt for the API."""
+        return "Call the intent tools to control Home Assistant. Just pass the name to the intent."
+
     @callback
     def async_get_tools(self) -> list[Tool]:
         """Return a list of LLM tools."""

View File

@@ -23,22 +23,7 @@
   dict({
     'history': list([
       dict({
-        'parts': '''
-          This smart home is controlled by Home Assistant.
-
-          An overview of the areas and the devices in this smart home:
-
-          Test Area:
-          - Test Device (Test Model)
-
-          Test Area 2:
-          - Test Device 2
-          - Test Device 3 (Test Model 3A)
-          - Test Device 4
-          - 1 (3)
-
-          Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
-        ''',
+        'parts': 'Answer in plain text. Keep it simple and to the point.',
         'role': 'user',
       }),
       dict({
@@ -82,22 +67,7 @@
   dict({
     'history': list([
      dict({
-        'parts': '''
-          This smart home is controlled by Home Assistant.
-
-          An overview of the areas and the devices in this smart home:
-
-          Test Area:
-          - Test Device (Test Model)
-
-          Test Area 2:
-          - Test Device 2
-          - Test Device 3 (Test Model 3A)
-          - Test Device 4
-          - 1 (3)
-
-          Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
-        ''',
+        'parts': 'Answer in plain text. Keep it simple and to the point.',
         'role': 'user',
       }),
       dict({
@@ -142,20 +112,8 @@
     'history': list([
       dict({
         'parts': '''
-          This smart home is controlled by Home Assistant.
-
-          An overview of the areas and the devices in this smart home:
-
-          Test Area:
-          - Test Device (Test Model)
-
-          Test Area 2:
-          - Test Device 2
-          - Test Device 3 (Test Model 3A)
-          - Test Device 4
-          - 1 (3)
-
-          Call the intent tools to control the system. Just pass the name to the intent.
+          Call the intent tools to control Home Assistant. Just pass the name to the intent.
+          Answer in plain text. Keep it simple and to the point.
         ''',
         'role': 'user',
       }),
@@ -201,20 +159,8 @@
     'history': list([
       dict({
         'parts': '''
-          This smart home is controlled by Home Assistant.
-
-          An overview of the areas and the devices in this smart home:
-
-          Test Area:
-          - Test Device (Test Model)
-
-          Test Area 2:
-          - Test Device 2
-          - Test Device 3 (Test Model 3A)
-          - Test Device 4
-          - 1 (3)
-
-          Call the intent tools to control the system. Just pass the name to the intent.
+          Call the intent tools to control Home Assistant. Just pass the name to the intent.
+          Answer in plain text. Keep it simple and to the point.
         ''',
         'role': 'user',
       }),

View File

@@ -13,7 +13,6 @@ from homeassistant.components.google_generative_ai_conversation.const import (
     CONF_PROMPT,
     CONF_RECOMMENDED,
     CONF_TEMPERATURE,
-    CONF_TONE_PROMPT,
     CONF_TOP_K,
     CONF_TOP_P,
     DOMAIN,
@@ -90,7 +89,7 @@ async def test_form(hass: HomeAssistant) -> None:
     assert result2["options"] == {
         CONF_RECOMMENDED: True,
         CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
-        CONF_TONE_PROMPT: "",
+        CONF_PROMPT: "",
     }
     assert len(mock_setup_entry.mock_calls) == 1
@@ -102,7 +101,7 @@ async def test_form(hass: HomeAssistant) -> None:
         {
             CONF_RECOMMENDED: True,
             CONF_LLM_HASS_API: "none",
-            CONF_TONE_PROMPT: "bla",
+            CONF_PROMPT: "bla",
         },
         {
             CONF_RECOMMENDED: False,
@@ -132,12 +131,12 @@ async def test_form(hass: HomeAssistant) -> None:
         {
             CONF_RECOMMENDED: True,
             CONF_LLM_HASS_API: "assist",
-            CONF_TONE_PROMPT: "",
+            CONF_PROMPT: "",
         },
         {
             CONF_RECOMMENDED: True,
             CONF_LLM_HASS_API: "assist",
-            CONF_TONE_PROMPT: "",
+            CONF_PROMPT: "",
         },
     ),
 ],

View File

@@ -11,7 +11,6 @@ from openai.types.chat.chat_completion_message_tool_call import (
     Function,
 )
 from openai.types.completion_usage import CompletionUsage
-import pytest
 from syrupy.assertion import SnapshotAssertion
 import voluptuous as vol
 
@@ -19,148 +18,12 @@ from openai.types.completion_usage import CompletionUsage
 from homeassistant.components import conversation
 from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import Context, HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
-from homeassistant.helpers import (
-    area_registry as ar,
-    device_registry as dr,
-    intent,
-    llm,
-)
+from homeassistant.helpers import intent, llm
 from homeassistant.setup import async_setup_component
 
 from tests.common import MockConfigEntry
 
 
-@pytest.mark.parametrize("agent_id", [None, "conversation.openai"])
-@pytest.mark.parametrize(
-    "config_entry_options", [{}, {CONF_LLM_HASS_API: llm.LLM_API_ASSIST}]
-)
-async def test_default_prompt(
-    hass: HomeAssistant,
-    mock_config_entry: MockConfigEntry,
-    mock_init_component,
-    area_registry: ar.AreaRegistry,
-    device_registry: dr.DeviceRegistry,
-    snapshot: SnapshotAssertion,
-    agent_id: str,
-    config_entry_options: dict,
-) -> None:
-    """Test that the default prompt works."""
-    entry = MockConfigEntry(title=None)
-    entry.add_to_hass(hass)
-    for i in range(3):
-        area_registry.async_create(f"{i}Empty Area")
-
-    if agent_id is None:
-        agent_id = mock_config_entry.entry_id
-
-    hass.config_entries.async_update_entry(
-        mock_config_entry,
-        options={
-            **mock_config_entry.options,
-            CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
-        },
-    )
-
-    device_registry.async_get_or_create(
-        config_entry_id=entry.entry_id,
-        connections={("test", "1234")},
-        name="Test Device",
-        manufacturer="Test Manufacturer",
-        model="Test Model",
-        suggested_area="Test Area",
-    )
-    for i in range(3):
-        device_registry.async_get_or_create(
-            config_entry_id=entry.entry_id,
-            connections={("test", f"{i}abcd")},
-            name="Test Service",
-            manufacturer="Test Manufacturer",
-            model="Test Model",
-            suggested_area="Test Area",
-            entry_type=dr.DeviceEntryType.SERVICE,
-        )
-    device_registry.async_get_or_create(
-        config_entry_id=entry.entry_id,
-        connections={("test", "5678")},
-        name="Test Device 2",
-        manufacturer="Test Manufacturer 2",
-        model="Device 2",
-        suggested_area="Test Area 2",
-    )
-    device_registry.async_get_or_create(
-        config_entry_id=entry.entry_id,
-        connections={("test", "9876")},
-        name="Test Device 3",
-        manufacturer="Test Manufacturer 3",
-        model="Test Model 3A",
-        suggested_area="Test Area 2",
-    )
-    device_registry.async_get_or_create(
-        config_entry_id=entry.entry_id,
-        connections={("test", "qwer")},
-        name="Test Device 4",
-        suggested_area="Test Area 2",
-    )
-    device = device_registry.async_get_or_create(
-        config_entry_id=entry.entry_id,
-        connections={("test", "9876-disabled")},
-        name="Test Device 3",
-        manufacturer="Test Manufacturer 3",
-        model="Test Model 3A",
-        suggested_area="Test Area 2",
-    )
-    device_registry.async_update_device(
-        device.id, disabled_by=dr.DeviceEntryDisabler.USER
-    )
-    device_registry.async_get_or_create(
-        config_entry_id=entry.entry_id,
-        connections={("test", "9876-no-name")},
-        manufacturer="Test Manufacturer NoName",
-        model="Test Model NoName",
-        suggested_area="Test Area 2",
-    )
-    device_registry.async_get_or_create(
-        config_entry_id=entry.entry_id,
-        connections={("test", "9876-integer-values")},
-        name=1,
-        manufacturer=2,
-        model=3,
-        suggested_area="Test Area 2",
-    )
-    with patch(
-        "openai.resources.chat.completions.AsyncCompletions.create",
-        new_callable=AsyncMock,
-        return_value=ChatCompletion(
-            id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
-            choices=[
-                Choice(
-                    finish_reason="stop",
-                    index=0,
-                    message=ChatCompletionMessage(
-                        content="Hello, how can I help you?",
-                        role="assistant",
-                        function_call=None,
-                        tool_calls=None,
-                    ),
-                )
-            ],
-            created=1700000000,
-            model="gpt-3.5-turbo-0613",
-            object="chat.completion",
-            system_fingerprint=None,
-            usage=CompletionUsage(
-                completion_tokens=9, prompt_tokens=8, total_tokens=17
-            ),
-        ),
-    ) as mock_create:
-        result = await conversation.async_converse(
-            hass, "hello", None, Context(), agent_id=agent_id
-        )
-
-    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
-    assert mock_create.mock_calls[0][2]["messages"] == snapshot
-
-
 async def test_error_handling(
     hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
 ) -> None:

View File

@@ -20,11 +20,15 @@ async def test_register_api(hass: HomeAssistant) -> None:
     """Test registering an llm api."""
 
     class MyAPI(llm.API):
+        async def async_get_api_prompt(self, tool_input: llm.ToolInput) -> str:
+            """Return a prompt for the tool."""
+            return ""
+
         def async_get_tools(self) -> list[llm.Tool]:
             """Return a list of tools."""
             return []
 
-    api = MyAPI(hass=hass, id="test", name="Test", prompt_template="")
+    api = MyAPI(hass=hass, id="test", name="Test")
     llm.async_register_api(hass, api)
 
     assert llm.async_get_api(hass, "test") is api