mirror of https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00

Add recommended model options to OpenAI (#118083)

* Add recommended options to OpenAI
* Use string join

This commit is contained in:
parent c59d4f9bba
commit 676fe5a9a2
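In short: the commit introduces a CONF_RECOMMENDED toggle. New config entries are created with RECOMMENDED_OPTIONS, the options form hides the model parameters while the toggle is on, and the conversation agent falls back to the RECOMMENDED_* constants whenever a key is absent. A minimal sketch of that fallback pattern (the plain-string keys here are stand-ins for the CONF_* constants in const.py):

    # Sketch: with the recommended toggle on, the advanced keys are never
    # saved, so every read falls back to a RECOMMENDED_* constant.
    RECOMMENDED_CHAT_MODEL = "gpt-4o"
    RECOMMENDED_MAX_TOKENS = 150

    options = {"recommended": True, "prompt": "Answer in plain text."}

    model = options.get("chat_model", RECOMMENDED_CHAT_MODEL)  # -> "gpt-4o"
    max_tokens = options.get("max_tokens", RECOMMENDED_MAX_TOKENS)  # -> 150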
--- a/homeassistant/components/openai_conversation/config_flow.py
+++ b/homeassistant/components/openai_conversation/config_flow.py
@@ -31,14 +31,15 @@ from .const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
     CONF_PROMPT,
+    CONF_RECOMMENDED,
     CONF_TEMPERATURE,
     CONF_TOP_P,
-    DEFAULT_CHAT_MODEL,
-    DEFAULT_MAX_TOKENS,
     DEFAULT_PROMPT,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_P,
     DOMAIN,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_TOP_P,
 )

 _LOGGER = logging.getLogger(__name__)
@@ -49,6 +50,12 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
     }
 )

+RECOMMENDED_OPTIONS = {
+    CONF_RECOMMENDED: True,
+    CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
+    CONF_PROMPT: DEFAULT_PROMPT,
+}
+

 async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
     """Validate the user input allows us to connect.
@@ -88,7 +95,7 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
             return self.async_create_entry(
                 title="ChatGPT",
                 data=user_input,
-                options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
+                options=RECOMMENDED_OPTIONS,
             )

         return self.async_show_form(
@@ -109,16 +116,32 @@ class OpenAIOptionsFlow(OptionsFlow):
     def __init__(self, config_entry: ConfigEntry) -> None:
         """Initialize options flow."""
         self.config_entry = config_entry
+        self.last_rendered_recommended = config_entry.options.get(
+            CONF_RECOMMENDED, False
+        )

     async def async_step_init(
         self, user_input: dict[str, Any] | None = None
     ) -> ConfigFlowResult:
         """Manage the options."""
+        options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
+
         if user_input is not None:
-            if user_input[CONF_LLM_HASS_API] == "none":
-                user_input.pop(CONF_LLM_HASS_API)
-            return self.async_create_entry(title="", data=user_input)
-        schema = openai_config_option_schema(self.hass, self.config_entry.options)
+            if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
+                if user_input[CONF_LLM_HASS_API] == "none":
+                    user_input.pop(CONF_LLM_HASS_API)
+                return self.async_create_entry(title="", data=user_input)
+
+            # Re-render the options again, now with the recommended options shown/hidden
+            self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
+
+            options = {
+                CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
+                CONF_PROMPT: user_input[CONF_PROMPT],
+                CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
+            }
+
+        schema = openai_config_option_schema(self.hass, options)
         return self.async_show_form(
             step_id="init",
             data_schema=vol.Schema(schema),
@@ -127,16 +150,16 @@ class OpenAIOptionsFlow(OptionsFlow):

 def openai_config_option_schema(
     hass: HomeAssistant,
-    options: MappingProxyType[str, Any],
+    options: dict[str, Any] | MappingProxyType[str, Any],
 ) -> dict:
     """Return a schema for OpenAI completion options."""
-    apis: list[SelectOptionDict] = [
+    hass_apis: list[SelectOptionDict] = [
         SelectOptionDict(
             label="No control",
             value="none",
         )
     ]
-    apis.extend(
+    hass_apis.extend(
         SelectOptionDict(
             label=api.name,
             value=api.id,
@@ -144,38 +167,46 @@ def openai_config_option_schema(
         for api in llm.async_get_apis(hass)
     )

-    return {
+    schema = {
         vol.Optional(
             CONF_PROMPT,
-            description={"suggested_value": options.get(CONF_PROMPT)},
-            default=DEFAULT_PROMPT,
+            description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
         ): TemplateSelector(),
         vol.Optional(
             CONF_LLM_HASS_API,
             description={"suggested_value": options.get(CONF_LLM_HASS_API)},
             default="none",
-        ): SelectSelector(SelectSelectorConfig(options=apis)),
-        vol.Optional(
-            CONF_CHAT_MODEL,
-            description={
-                # New key in HA 2023.4
-                "suggested_value": options.get(CONF_CHAT_MODEL)
-            },
-            default=DEFAULT_CHAT_MODEL,
-        ): str,
-        vol.Optional(
-            CONF_MAX_TOKENS,
-            description={"suggested_value": options.get(CONF_MAX_TOKENS)},
-            default=DEFAULT_MAX_TOKENS,
-        ): int,
-        vol.Optional(
-            CONF_TOP_P,
-            description={"suggested_value": options.get(CONF_TOP_P)},
-            default=DEFAULT_TOP_P,
-        ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
-        vol.Optional(
-            CONF_TEMPERATURE,
-            description={"suggested_value": options.get(CONF_TEMPERATURE)},
-            default=DEFAULT_TEMPERATURE,
-        ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
+        ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
+        vol.Required(
+            CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
+        ): bool,
     }
+
+    if options.get(CONF_RECOMMENDED):
+        return schema
+
+    schema.update(
+        {
+            vol.Optional(
+                CONF_CHAT_MODEL,
+                description={"suggested_value": options.get(CONF_CHAT_MODEL)},
+                default=RECOMMENDED_CHAT_MODEL,
+            ): str,
+            vol.Optional(
+                CONF_MAX_TOKENS,
+                description={"suggested_value": options.get(CONF_MAX_TOKENS)},
+                default=RECOMMENDED_MAX_TOKENS,
+            ): int,
+            vol.Optional(
+                CONF_TOP_P,
+                description={"suggested_value": options.get(CONF_TOP_P)},
+                default=RECOMMENDED_TOP_P,
+            ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
+            vol.Optional(
+                CONF_TEMPERATURE,
+                description={"suggested_value": options.get(CONF_TEMPERATURE)},
+                default=RECOMMENDED_TEMPERATURE,
+            ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
+        }
+    )
+    return schema
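The options flow above is a two-pass form: flipping the recommended toggle re-renders the schema instead of saving, and only a submission whose toggle matches the last rendered state is persisted. A standalone sketch of that decision (hypothetical helper, not part of the integration):

    # Sketch of the save-vs-re-render decision in OpenAIOptionsFlow.async_step_init.
    def handle_submit(user_input: dict, last_rendered_recommended: bool) -> str:
        if user_input["recommended"] == last_rendered_recommended:
            return "save"       # async_create_entry in the real flow
        return "re-render"      # async_show_form with a rebuilt schema

    assert handle_submit({"recommended": True}, True) == "save"
    assert handle_submit({"recommended": False}, True) == "re-render"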
--- a/homeassistant/components/openai_conversation/const.py
+++ b/homeassistant/components/openai_conversation/const.py
@@ -4,13 +4,15 @@ import logging

 DOMAIN = "openai_conversation"
 LOGGER = logging.getLogger(__package__)
+
+CONF_RECOMMENDED = "recommended"
 CONF_PROMPT = "prompt"
 DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
 CONF_CHAT_MODEL = "chat_model"
-DEFAULT_CHAT_MODEL = "gpt-4o"
+RECOMMENDED_CHAT_MODEL = "gpt-4o"
 CONF_MAX_TOKENS = "max_tokens"
-DEFAULT_MAX_TOKENS = 150
+RECOMMENDED_MAX_TOKENS = 150
 CONF_TOP_P = "top_p"
-DEFAULT_TOP_P = 1.0
+RECOMMENDED_TOP_P = 1.0
 CONF_TEMPERATURE = "temperature"
-DEFAULT_TEMPERATURE = 1.0
+RECOMMENDED_TEMPERATURE = 1.0
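Renaming DEFAULT_* to RECOMMENDED_* reflects that these values are no longer voluptuous schema defaults written into every entry; they are runtime fallbacks that apply only while the advanced fields are hidden. A toy illustration (hypothetical keys):

    # An entry saved in recommended mode stores no model keys at all...
    saved = {"recommended": True}
    assert saved.get("max_tokens", 150) == 150  # falls back to RECOMMENDED_MAX_TOKENS

    # ...while an entry saved with the toggle off keeps explicit values.
    saved = {"recommended": False, "max_tokens": 200}
    assert saved.get("max_tokens", 150) == 200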
--- a/homeassistant/components/openai_conversation/conversation.py
+++ b/homeassistant/components/openai_conversation/conversation.py
@@ -22,13 +22,13 @@ from .const import (
     CONF_PROMPT,
     CONF_TEMPERATURE,
     CONF_TOP_P,
-    DEFAULT_CHAT_MODEL,
-    DEFAULT_MAX_TOKENS,
     DEFAULT_PROMPT,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_P,
     DOMAIN,
     LOGGER,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_TOP_P,
 )

 # Max number of back and forth with the LLM to generate a response
@@ -97,15 +97,14 @@ class OpenAIConversationEntity(
         self, user_input: conversation.ConversationInput
     ) -> conversation.ConversationResult:
         """Process a sentence."""
+        options = self.entry.options
         intent_response = intent.IntentResponse(language=user_input.language)
         llm_api: llm.API | None = None
         tools: list[dict[str, Any]] | None = None

-        if self.entry.options.get(CONF_LLM_HASS_API):
+        if options.get(CONF_LLM_HASS_API):
             try:
-                llm_api = llm.async_get_api(
-                    self.hass, self.entry.options[CONF_LLM_HASS_API]
-                )
+                llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API])
             except HomeAssistantError as err:
                 LOGGER.error("Error getting LLM API: %s", err)
                 intent_response.async_set_error(
@@ -117,26 +116,12 @@ class OpenAIConversationEntity(
                 )
             tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]

-        model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
-        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
-        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
-        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
-
         if user_input.conversation_id in self.history:
             conversation_id = user_input.conversation_id
             messages = self.history[conversation_id]
         else:
             conversation_id = ulid.ulid_now()
             try:
-                prompt = template.Template(
-                    self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
-                ).async_render(
-                    {
-                        "ha_name": self.hass.config.location_name,
-                    },
-                    parse_result=False,
-                )
-
                 if llm_api:
                     empty_tool_input = llm.ToolInput(
                         tool_name="",
@@ -149,11 +134,24 @@ class OpenAIConversationEntity(
                         device_id=user_input.device_id,
                     )

-                prompt = (
-                    await llm_api.async_get_api_prompt(empty_tool_input)
-                    + "\n"
-                    + prompt
+                    api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
+
+                else:
+                    api_prompt = llm.PROMPT_NO_API_CONFIGURED
+
+                prompt = "\n".join(
+                    (
+                        api_prompt,
+                        template.Template(
+                            options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
+                        ).async_render(
+                            {
+                                "ha_name": self.hass.config.location_name,
+                            },
+                            parse_result=False,
+                        ),
+                    )
                 )

             except TemplateError as err:
                 LOGGER.error("Error rendering prompt: %s", err)
@@ -170,7 +168,7 @@ class OpenAIConversationEntity(

         messages.append({"role": "user", "content": user_input.text})

-        LOGGER.debug("Prompt for %s: %s", model, messages)
+        LOGGER.debug("Prompt: %s", messages)

         client = self.hass.data[DOMAIN][self.entry.entry_id]

@@ -178,12 +176,12 @@ class OpenAIConversationEntity(
         for _iteration in range(MAX_TOOL_ITERATIONS):
             try:
                 result = await client.chat.completions.create(
-                    model=model,
+                    model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
                     messages=messages,
                     tools=tools,
-                    max_tokens=max_tokens,
-                    top_p=top_p,
-                    temperature=temperature,
+                    max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
+                    top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
+                    temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
                     user=conversation_id,
                 )
             except openai.OpenAIError as err:
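The second commit bullet, "Use string join", refers to the prompt assembly above: the API prompt and the rendered user template are combined with "\n".join over a tuple instead of chained "+" concatenation, which also makes the no-API branch (llm.PROMPT_NO_API_CONFIGURED) symmetrical. A sketch with dummy strings:

    # Sketch of the joined prompt; real values come from llm_api and the template.
    api_prompt = "You can control Home Assistant."  # or llm.PROMPT_NO_API_CONFIGURED
    rendered_user_prompt = "Answer in plain text. Keep it simple and to the point."

    prompt = "\n".join((api_prompt, rendered_user_prompt))
    # -> "You can control Home Assistant.\nAnswer in plain text. ..."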
--- a/homeassistant/components/openai_conversation/strings.json
+++ b/homeassistant/components/openai_conversation/strings.json
@@ -22,7 +22,8 @@
           "max_tokens": "Maximum tokens to return in response",
           "temperature": "Temperature",
           "top_p": "Top P",
-          "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
+          "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
+          "recommended": "Recommended model settings"
         },
         "data_description": {
           "prompt": "Instruct how the LLM should respond. This can be a template."
--- a/tests/components/openai_conversation/test_config_flow.py
+++ b/tests/components/openai_conversation/test_config_flow.py
@@ -9,9 +9,17 @@ import pytest
 from homeassistant import config_entries
 from homeassistant.components.openai_conversation.const import (
     CONF_CHAT_MODEL,
-    DEFAULT_CHAT_MODEL,
+    CONF_MAX_TOKENS,
+    CONF_PROMPT,
+    CONF_RECOMMENDED,
+    CONF_TEMPERATURE,
+    CONF_TOP_P,
     DOMAIN,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TOP_P,
 )
+from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant
 from homeassistant.data_entry_flow import FlowResultType

@@ -75,7 +83,7 @@ async def test_options(
     assert options["type"] is FlowResultType.CREATE_ENTRY
     assert options["data"]["prompt"] == "Speak like a pirate"
     assert options["data"]["max_tokens"] == 200
-    assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL
+    assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL


 @pytest.mark.parametrize(
@@ -115,3 +123,78 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:

     assert result2["type"] is FlowResultType.FORM
     assert result2["errors"] == {"base": error}
+
+
+@pytest.mark.parametrize(
+    ("current_options", "new_options", "expected_options"),
+    [
+        (
+            {
+                CONF_RECOMMENDED: True,
+                CONF_LLM_HASS_API: "none",
+                CONF_PROMPT: "bla",
+            },
+            {
+                CONF_RECOMMENDED: False,
+                CONF_PROMPT: "Speak like a pirate",
+                CONF_TEMPERATURE: 0.3,
+            },
+            {
+                CONF_RECOMMENDED: False,
+                CONF_PROMPT: "Speak like a pirate",
+                CONF_TEMPERATURE: 0.3,
+                CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
+                CONF_TOP_P: RECOMMENDED_TOP_P,
+                CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+            },
+        ),
+        (
+            {
+                CONF_RECOMMENDED: False,
+                CONF_PROMPT: "Speak like a pirate",
+                CONF_TEMPERATURE: 0.3,
+                CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
+                CONF_TOP_P: RECOMMENDED_TOP_P,
+                CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+            },
+            {
+                CONF_RECOMMENDED: True,
+                CONF_LLM_HASS_API: "assist",
+                CONF_PROMPT: "",
+            },
+            {
+                CONF_RECOMMENDED: True,
+                CONF_LLM_HASS_API: "assist",
+                CONF_PROMPT: "",
+            },
+        ),
+    ],
+)
+async def test_options_switching(
+    hass: HomeAssistant,
+    mock_config_entry,
+    mock_init_component,
+    current_options,
+    new_options,
+    expected_options,
+) -> None:
+    """Test the options form."""
+    hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
+    options_flow = await hass.config_entries.options.async_init(
+        mock_config_entry.entry_id
+    )
+    if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED):
+        options_flow = await hass.config_entries.options.async_configure(
+            options_flow["flow_id"],
+            {
+                **current_options,
+                CONF_RECOMMENDED: new_options[CONF_RECOMMENDED],
+            },
+        )
+    options = await hass.config_entries.options.async_configure(
+        options_flow["flow_id"],
+        new_options,
+    )
+    await hass.async_block_till_done()
+    assert options["type"] is FlowResultType.CREATE_ENTRY
+    assert options["data"] == expected_options
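test_options_switching drives the options form twice when the recommended flag flips (the first submission toggles and re-renders, the second carries the real values) and once otherwise. A condensed, hypothetical restatement of that rule:

    # How many submissions the test performs per parametrized case.
    def submissions_needed(current: dict, new: dict) -> int:
        return 2 if current.get("recommended") != new.get("recommended") else 1

    assert submissions_needed({"recommended": True}, {"recommended": False}) == 2
    assert submissions_needed({"recommended": True}, {"recommended": True}) == 1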