Add recommended model options to OpenAI (#118083)

* Add recommended options to OpenAI

* Use string join
Paulus Schoutsen, 2024-05-25 00:01:48 -04:00 (committed by GitHub)
parent c59d4f9bba
commit 676fe5a9a2
5 changed files with 192 additions and 77 deletions
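In short: new config entries are created with a RECOMMENDED_OPTIONS preset, and the options UI hides the model-tuning fields while that preset is enabled. A minimal sketch of the two option shapes an entry can carry after this change (key names match const.py below; the prompt values are placeholders):

    # Recommended mode: only the toggle, the Home Assistant LLM API choice,
    # and the prompt are stored.
    recommended = {"recommended": True, "llm_hass_api": "assist", "prompt": "..."}

    # Custom mode: the tuning keys are stored as well; when absent they fall
    # back to the RECOMMENDED_* constants (gpt-4o, 150 tokens, top_p/temperature 1.0).
    custom = {
        "recommended": False,
        "prompt": "...",
        "chat_model": "gpt-4o",
        "max_tokens": 150,
        "top_p": 1.0,
        "temperature": 1.0,
    }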

homeassistant/components/openai_conversation/config_flow.py

@@ -31,14 +31,15 @@ from .const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
     CONF_PROMPT,
+    CONF_RECOMMENDED,
     CONF_TEMPERATURE,
     CONF_TOP_P,
-    DEFAULT_CHAT_MODEL,
-    DEFAULT_MAX_TOKENS,
     DEFAULT_PROMPT,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_P,
     DOMAIN,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_TOP_P,
 )
 
 _LOGGER = logging.getLogger(__name__)
@@ -49,6 +50,12 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
     }
 )
 
+RECOMMENDED_OPTIONS = {
+    CONF_RECOMMENDED: True,
+    CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
+    CONF_PROMPT: DEFAULT_PROMPT,
+}
+
 
 async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
     """Validate the user input allows us to connect.
@@ -88,7 +95,7 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
             return self.async_create_entry(
                 title="ChatGPT",
                 data=user_input,
-                options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
+                options=RECOMMENDED_OPTIONS,
             )
 
         return self.async_show_form(
@@ -109,16 +116,32 @@ class OpenAIOptionsFlow(OptionsFlow):
     def __init__(self, config_entry: ConfigEntry) -> None:
         """Initialize options flow."""
         self.config_entry = config_entry
+        self.last_rendered_recommended = config_entry.options.get(
+            CONF_RECOMMENDED, False
+        )
 
     async def async_step_init(
         self, user_input: dict[str, Any] | None = None
     ) -> ConfigFlowResult:
         """Manage the options."""
+        options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
+
         if user_input is not None:
-            if user_input[CONF_LLM_HASS_API] == "none":
-                user_input.pop(CONF_LLM_HASS_API)
-            return self.async_create_entry(title="", data=user_input)
+            if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
+                if user_input[CONF_LLM_HASS_API] == "none":
+                    user_input.pop(CONF_LLM_HASS_API)
+                return self.async_create_entry(title="", data=user_input)
+
+            # Re-render the options again, now with the recommended options shown/hidden
+            self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
+
+            options = {
+                CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
+                CONF_PROMPT: user_input[CONF_PROMPT],
+                CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
+            }
 
-        schema = openai_config_option_schema(self.hass, self.config_entry.options)
+        schema = openai_config_option_schema(self.hass, options)
         return self.async_show_form(
             step_id="init",
             data_schema=vol.Schema(schema),
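The save rule in async_step_init is subtle: input is persisted only when the recommended toggle still has the value the form was last rendered with; if the user flipped it, the flow re-renders with the other field set, carrying over just the fields common to both layouts. A framework-free sketch of that pattern (function and variable names here are illustrative, not Home Assistant APIs):

    def submit(state: dict, user_input: dict) -> tuple[str, dict]:
        # Toggle unchanged since the last render: this submission is final.
        if user_input["recommended"] == state["last_rendered_recommended"]:
            return ("save", user_input)
        # Toggle flipped: remember the new layout and re-render the form,
        # keeping only the fields that exist in both layouts.
        state["last_rendered_recommended"] = user_input["recommended"]
        carried = {k: user_input[k] for k in ("recommended", "prompt", "llm_hass_api")}
        return ("render", carried)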
@@ -127,16 +150,16 @@ class OpenAIOptionsFlow(OptionsFlow):
 def openai_config_option_schema(
     hass: HomeAssistant,
-    options: MappingProxyType[str, Any],
+    options: dict[str, Any] | MappingProxyType[str, Any],
 ) -> dict:
     """Return a schema for OpenAI completion options."""
-    apis: list[SelectOptionDict] = [
+    hass_apis: list[SelectOptionDict] = [
         SelectOptionDict(
             label="No control",
             value="none",
         )
     ]
-    apis.extend(
+    hass_apis.extend(
         SelectOptionDict(
             label=api.name,
             value=api.id,
@@ -144,38 +167,46 @@ def openai_config_option_schema(
         )
         for api in llm.async_get_apis(hass)
     )
 
-    return {
+    schema = {
         vol.Optional(
             CONF_PROMPT,
-            description={"suggested_value": options.get(CONF_PROMPT)},
-            default=DEFAULT_PROMPT,
+            description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
         ): TemplateSelector(),
         vol.Optional(
             CONF_LLM_HASS_API,
             description={"suggested_value": options.get(CONF_LLM_HASS_API)},
             default="none",
-        ): SelectSelector(SelectSelectorConfig(options=apis)),
-        vol.Optional(
-            CONF_CHAT_MODEL,
-            description={
-                # New key in HA 2023.4
-                "suggested_value": options.get(CONF_CHAT_MODEL)
-            },
-            default=DEFAULT_CHAT_MODEL,
-        ): str,
-        vol.Optional(
-            CONF_MAX_TOKENS,
-            description={"suggested_value": options.get(CONF_MAX_TOKENS)},
-            default=DEFAULT_MAX_TOKENS,
-        ): int,
-        vol.Optional(
-            CONF_TOP_P,
-            description={"suggested_value": options.get(CONF_TOP_P)},
-            default=DEFAULT_TOP_P,
-        ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
-        vol.Optional(
-            CONF_TEMPERATURE,
-            description={"suggested_value": options.get(CONF_TEMPERATURE)},
-            default=DEFAULT_TEMPERATURE,
-        ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
+        ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
+        vol.Required(
+            CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
+        ): bool,
     }
+
+    if options.get(CONF_RECOMMENDED):
+        return schema
+
+    schema.update(
+        {
+            vol.Optional(
+                CONF_CHAT_MODEL,
+                description={"suggested_value": options.get(CONF_CHAT_MODEL)},
+                default=RECOMMENDED_CHAT_MODEL,
+            ): str,
+            vol.Optional(
+                CONF_MAX_TOKENS,
+                description={"suggested_value": options.get(CONF_MAX_TOKENS)},
+                default=RECOMMENDED_MAX_TOKENS,
+            ): int,
+            vol.Optional(
+                CONF_TOP_P,
+                description={"suggested_value": options.get(CONF_TOP_P)},
+                default=RECOMMENDED_TOP_P,
+            ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
+            vol.Optional(
+                CONF_TEMPERATURE,
+                description={"suggested_value": options.get(CONF_TEMPERATURE)},
+                default=RECOMMENDED_TEMPERATURE,
+            ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
+        }
+    )
+    return schema
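openai_config_option_schema now has two exits: the base fields alone when the recommended preset is on, or base plus tuning fields otherwise. A condensed, runnable sketch of the same shape with plain voluptuous (field set trimmed for brevity; defaults mirror const.py):

    import voluptuous as vol

    def option_schema(options: dict) -> vol.Schema:
        # Base fields are always present.
        fields: dict = {
            vol.Required("recommended", default=options.get("recommended", False)): bool,
        }
        if options.get("recommended"):
            return vol.Schema(fields)
        # Tuning fields only appear once the preset is switched off.
        fields[vol.Optional("max_tokens", default=150)] = int
        fields[vol.Optional("temperature", default=1.0)] = float
        return vol.Schema(fields)

    assert "max_tokens" not in {str(key) for key in option_schema({"recommended": True}).schema}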

homeassistant/components/openai_conversation/const.py

@@ -4,13 +4,15 @@ import logging
 DOMAIN = "openai_conversation"
 LOGGER = logging.getLogger(__package__)
 
+CONF_RECOMMENDED = "recommended"
 CONF_PROMPT = "prompt"
 DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
+
 CONF_CHAT_MODEL = "chat_model"
-DEFAULT_CHAT_MODEL = "gpt-4o"
+RECOMMENDED_CHAT_MODEL = "gpt-4o"
 CONF_MAX_TOKENS = "max_tokens"
-DEFAULT_MAX_TOKENS = 150
+RECOMMENDED_MAX_TOKENS = 150
 CONF_TOP_P = "top_p"
-DEFAULT_TOP_P = 1.0
+RECOMMENDED_TOP_P = 1.0
 CONF_TEMPERATURE = "temperature"
-DEFAULT_TEMPERATURE = 1.0
+RECOMMENDED_TEMPERATURE = 1.0

homeassistant/components/openai_conversation/conversation.py

@@ -22,13 +22,13 @@ from .const import (
     CONF_PROMPT,
     CONF_TEMPERATURE,
     CONF_TOP_P,
-    DEFAULT_CHAT_MODEL,
-    DEFAULT_MAX_TOKENS,
     DEFAULT_PROMPT,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_P,
     DOMAIN,
     LOGGER,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_TOP_P,
 )
 
 # Max number of back and forth with the LLM to generate a response
@@ -97,15 +97,14 @@
         self, user_input: conversation.ConversationInput
     ) -> conversation.ConversationResult:
         """Process a sentence."""
+        options = self.entry.options
         intent_response = intent.IntentResponse(language=user_input.language)
         llm_api: llm.API | None = None
         tools: list[dict[str, Any]] | None = None
 
-        if self.entry.options.get(CONF_LLM_HASS_API):
+        if options.get(CONF_LLM_HASS_API):
             try:
-                llm_api = llm.async_get_api(
-                    self.hass, self.entry.options[CONF_LLM_HASS_API]
-                )
+                llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API])
             except HomeAssistantError as err:
                 LOGGER.error("Error getting LLM API: %s", err)
                 intent_response.async_set_error(
@@ -117,26 +116,12 @@
                 )
 
             tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
 
-        model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
-        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
-        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
-        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
-
         if user_input.conversation_id in self.history:
             conversation_id = user_input.conversation_id
             messages = self.history[conversation_id]
         else:
             conversation_id = ulid.ulid_now()
 
             try:
-                prompt = template.Template(
-                    self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
-                ).async_render(
-                    {
-                        "ha_name": self.hass.config.location_name,
-                    },
-                    parse_result=False,
-                )
                 if llm_api:
                     empty_tool_input = llm.ToolInput(
                         tool_name="",
@@ -149,11 +134,24 @@
                         device_id=user_input.device_id,
                     )
-                    prompt = (
-                        await llm_api.async_get_api_prompt(empty_tool_input)
-                        + "\n"
-                        + prompt
-                    )
+                    api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
+                else:
+                    api_prompt = llm.PROMPT_NO_API_CONFIGURED
+
+                prompt = "\n".join(
+                    (
+                        api_prompt,
+                        template.Template(
+                            options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
+                        ).async_render(
+                            {
+                                "ha_name": self.hass.config.location_name,
+                            },
+                            parse_result=False,
+                        ),
+                    )
+                )
 
         except TemplateError as err:
             LOGGER.error("Error rendering prompt: %s", err)
@@ -170,7 +168,7 @@
         messages.append({"role": "user", "content": user_input.text})
 
-        LOGGER.debug("Prompt for %s: %s", model, messages)
+        LOGGER.debug("Prompt: %s", messages)
 
         client = self.hass.data[DOMAIN][self.entry.entry_id]
@@ -178,12 +176,12 @@
         for _iteration in range(MAX_TOOL_ITERATIONS):
             try:
                 result = await client.chat.completions.create(
-                    model=model,
+                    model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
                     messages=messages,
                     tools=tools,
-                    max_tokens=max_tokens,
-                    top_p=top_p,
-                    temperature=temperature,
+                    max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
+                    top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
+                    temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
                     user=conversation_id,
                 )
             except openai.OpenAIError as err:
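Note the pattern change in this call: instead of resolving the model parameters once near the top of the handler, each request reads them lazily with options.get(key, RECOMMENDED_*). Entries saved in recommended mode simply omit the tuning keys, so they pick up the defaults without any migration. A minimal illustration with plain dicts (constant values copied from const.py):

    RECOMMENDED_MAX_TOKENS = 150
    RECOMMENDED_TEMPERATURE = 1.0

    # An entry saved in recommended mode stores no tuning keys at all.
    options = {"recommended": True, "prompt": "Answer in plain text."}

    assert options.get("max_tokens", RECOMMENDED_MAX_TOKENS) == 150
    assert options.get("temperature", RECOMMENDED_TEMPERATURE) == 1.0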

homeassistant/components/openai_conversation/strings.json

@@ -22,7 +22,8 @@
         "max_tokens": "Maximum tokens to return in response",
         "temperature": "Temperature",
         "top_p": "Top P",
-        "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
+        "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
+        "recommended": "Recommended model settings"
       },
       "data_description": {
         "prompt": "Instruct how the LLM should respond. This can be a template."

tests/components/openai_conversation/test_config_flow.py

@@ -9,9 +9,17 @@ import pytest
 from homeassistant import config_entries
 from homeassistant.components.openai_conversation.const import (
     CONF_CHAT_MODEL,
-    DEFAULT_CHAT_MODEL,
+    CONF_MAX_TOKENS,
+    CONF_PROMPT,
+    CONF_RECOMMENDED,
+    CONF_TEMPERATURE,
+    CONF_TOP_P,
     DOMAIN,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TOP_P,
 )
+from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant
 from homeassistant.data_entry_flow import FlowResultType
@@ -75,7 +83,7 @@
     assert options["type"] is FlowResultType.CREATE_ENTRY
     assert options["data"]["prompt"] == "Speak like a pirate"
     assert options["data"]["max_tokens"] == 200
-    assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL
+    assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL
 
 
 @pytest.mark.parametrize(
@@ -115,3 +123,78 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
     assert result2["type"] is FlowResultType.FORM
     assert result2["errors"] == {"base": error}
 
+
+@pytest.mark.parametrize(
+    ("current_options", "new_options", "expected_options"),
+    [
+        (
+            {
+                CONF_RECOMMENDED: True,
+                CONF_LLM_HASS_API: "none",
+                CONF_PROMPT: "bla",
+            },
+            {
+                CONF_RECOMMENDED: False,
+                CONF_PROMPT: "Speak like a pirate",
+                CONF_TEMPERATURE: 0.3,
+            },
+            {
+                CONF_RECOMMENDED: False,
+                CONF_PROMPT: "Speak like a pirate",
+                CONF_TEMPERATURE: 0.3,
+                CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
+                CONF_TOP_P: RECOMMENDED_TOP_P,
+                CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+            },
+        ),
+        (
+            {
+                CONF_RECOMMENDED: False,
+                CONF_PROMPT: "Speak like a pirate",
+                CONF_TEMPERATURE: 0.3,
+                CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
+                CONF_TOP_P: RECOMMENDED_TOP_P,
+                CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+            },
+            {
+                CONF_RECOMMENDED: True,
+                CONF_LLM_HASS_API: "assist",
+                CONF_PROMPT: "",
+            },
+            {
+                CONF_RECOMMENDED: True,
+                CONF_LLM_HASS_API: "assist",
+                CONF_PROMPT: "",
+            },
+        ),
+    ],
+)
+async def test_options_switching(
+    hass: HomeAssistant,
+    mock_config_entry,
+    mock_init_component,
+    current_options,
+    new_options,
+    expected_options,
+) -> None:
+    """Test the options form."""
+    hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
+    options_flow = await hass.config_entries.options.async_init(
+        mock_config_entry.entry_id
+    )
+    if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED):
+        options_flow = await hass.config_entries.options.async_configure(
+            options_flow["flow_id"],
+            {
+                **current_options,
+                CONF_RECOMMENDED: new_options[CONF_RECOMMENDED],
+            },
+        )
+    options = await hass.config_entries.options.async_configure(
+        options_flow["flow_id"],
+        new_options,
+    )
+    await hass.async_block_till_done()
+
+    assert options["type"] is FlowResultType.CREATE_ENTRY
+    assert options["data"] == expected_options
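The first parametrized case expects CONF_CHAT_MODEL, CONF_TOP_P, and CONF_MAX_TOKENS in the saved data even though the simulated user never submitted them: once the toggle is off, those fields exist in the schema with RECOMMENDED_* defaults, and voluptuous fills defaults for optional keys missing from a submission. A self-contained sketch of that mechanism (my reading of where the extra values come from, distilled from the schema code above):

    import voluptuous as vol

    schema = vol.Schema({
        vol.Required("recommended"): bool,
        vol.Optional("max_tokens", default=150): int,
    })

    # Keys omitted from the submitted data are filled in from the field defaults.
    assert schema({"recommended": False}) == {"recommended": False, "max_tokens": 150}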