Mirror of https://github.com/home-assistant/core.git (synced 2025-07-23 13:17:32 +00:00)
Add support for OpenAI reasoning models (#137139)
* Add support for OpenAI reasoning models
* Apply suggestions from code review
* Remove o1-mini* and o1-preview* model support
* List unsupported models
* Reenable audio models (they also support text)
commit 0f36759a38 (parent a6781107df)
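For context on what the new option controls: OpenAI's o-series reasoning models accept a `reasoning_effort` parameter ("low", "medium", or "high") that governs how many reasoning tokens the model spends before it answers; this commit defaults it to "low". A minimal sketch of the underlying API call, not part of the commit (assumes a recent version of the official `openai` package; the model name and prompt are placeholders):

    from openai import OpenAI

    client = OpenAI()  # reads OPENAI_API_KEY from the environment

    result = client.chat.completions.create(
        model="o3-mini",          # placeholder; any o-series reasoning model
        reasoning_effort="low",   # "low" | "medium" | "high"
        messages=[{"role": "user", "content": "Turn off the lights in 5 words."}],
    )
    print(result.choices[0].message.content)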
--- a/homeassistant/components/openai_conversation/config_flow.py
+++ b/homeassistant/components/openai_conversation/config_flow.py
@@ -24,6 +24,7 @@ from homeassistant.helpers.selector import (
     SelectOptionDict,
     SelectSelector,
     SelectSelectorConfig,
+    SelectSelectorMode,
     TemplateSelector,
 )
 from homeassistant.helpers.typing import VolDictType
@@ -32,14 +33,17 @@ from .const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
     CONF_PROMPT,
+    CONF_REASONING_EFFORT,
     CONF_RECOMMENDED,
     CONF_TEMPERATURE,
     CONF_TOP_P,
     DOMAIN,
     RECOMMENDED_CHAT_MODEL,
     RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_REASONING_EFFORT,
     RECOMMENDED_TEMPERATURE,
     RECOMMENDED_TOP_P,
+    UNSUPPORTED_MODELS,
 )
 
 _LOGGER = logging.getLogger(__name__)
@@ -124,26 +128,32 @@ class OpenAIOptionsFlow(OptionsFlow):
     ) -> ConfigFlowResult:
         """Manage the options."""
         options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
+        errors: dict[str, str] = {}
 
         if user_input is not None:
             if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
                 if user_input[CONF_LLM_HASS_API] == "none":
                     user_input.pop(CONF_LLM_HASS_API)
-                return self.async_create_entry(title="", data=user_input)
 
-            # Re-render the options again, now with the recommended options shown/hidden
-            self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
+                if user_input.get(CONF_CHAT_MODEL) in UNSUPPORTED_MODELS:
+                    errors[CONF_CHAT_MODEL] = "model_not_supported"
+                else:
+                    return self.async_create_entry(title="", data=user_input)
+            else:
+                # Re-render the options again, now with the recommended options shown/hidden
+                self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
 
-            options = {
-                CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
-                CONF_PROMPT: user_input[CONF_PROMPT],
-                CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
-            }
+                options = {
+                    CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
+                    CONF_PROMPT: user_input[CONF_PROMPT],
+                    CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
+                }
 
         schema = openai_config_option_schema(self.hass, options)
         return self.async_show_form(
             step_id="init",
             data_schema=vol.Schema(schema),
+            errors=errors,
         )
 
 
@@ -210,6 +220,17 @@ def openai_config_option_schema(
                 description={"suggested_value": options.get(CONF_TEMPERATURE)},
                 default=RECOMMENDED_TEMPERATURE,
             ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
+            vol.Optional(
+                CONF_REASONING_EFFORT,
+                description={"suggested_value": options.get(CONF_REASONING_EFFORT)},
+                default=RECOMMENDED_REASONING_EFFORT,
+            ): SelectSelector(
+                SelectSelectorConfig(
+                    options=["low", "medium", "high"],
+                    translation_key="reasoning_effort",
+                    mode=SelectSelectorMode.DROPDOWN,
+                )
+            ),
         }
     )
     return schema
--- a/homeassistant/components/openai_conversation/const.py
+++ b/homeassistant/components/openai_conversation/const.py
@@ -15,3 +15,17 @@ CONF_TOP_P = "top_p"
 RECOMMENDED_TOP_P = 1.0
 CONF_TEMPERATURE = "temperature"
 RECOMMENDED_TEMPERATURE = 1.0
+CONF_REASONING_EFFORT = "reasoning_effort"
+RECOMMENDED_REASONING_EFFORT = "low"
+
+UNSUPPORTED_MODELS = [
+    "o1-mini",
+    "o1-mini-2024-09-12",
+    "o1-preview",
+    "o1-preview-2024-09-12",
+    "gpt-4o-realtime-preview",
+    "gpt-4o-realtime-preview-2024-12-17",
+    "gpt-4o-realtime-preview-2024-10-01",
+    "gpt-4o-mini-realtime-preview",
+    "gpt-4o-mini-realtime-preview-2024-12-17",
+]
--- a/homeassistant/components/openai_conversation/conversation.py
+++ b/homeassistant/components/openai_conversation/conversation.py
@@ -31,12 +31,14 @@ from .const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
     CONF_PROMPT,
+    CONF_REASONING_EFFORT,
     CONF_TEMPERATURE,
     CONF_TOP_P,
     DOMAIN,
     LOGGER,
     RECOMMENDED_CHAT_MODEL,
     RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_REASONING_EFFORT,
     RECOMMENDED_TEMPERATURE,
     RECOMMENDED_TOP_P,
 )
@@ -97,12 +99,15 @@ def _chat_message_convert(
     | conversation.NativeContent[ChatCompletionMessageParam],
 ) -> ChatCompletionMessageParam:
     """Convert any native chat message for this agent to the native format."""
-    if message.role == "native":
+    role = message.role
+    if role == "native":
         # mypy doesn't understand that checking role ensures content type
         return message.content  # type: ignore[return-value]
+    if role == "system":
+        role = "developer"
     return cast(
         ChatCompletionMessageParam,
-        {"role": message.role, "content": message.content},
+        {"role": role, "content": message.content},
     )
 
 
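Background for the role rewrite above: OpenAI's reasoning models reject the "system" role and expect "developer" in its place, so the converter now remaps the role before building the message parameter. A self-contained illustration of the same mapping, using a stand-in message type rather than the integration's real conversation classes:

    from dataclasses import dataclass


    @dataclass
    class Message:  # stand-in for the integration's conversation content types
        role: str
        content: str


    def to_openai_message(message: Message) -> dict[str, str]:
        """Map a chat role to what the OpenAI API expects (sketch)."""
        role = message.role
        if role == "system":
            # o-series models use "developer" instead of "system".
            role = "developer"
        return {"role": role, "content": message.content}


    assert to_openai_message(Message("system", "Be concise."))["role"] == "developer"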
@@ -189,6 +194,8 @@ class OpenAIConversationEntity(
             for tool in session.llm_api.tools
         ]
 
+        model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
+
         messages = [
             _chat_message_convert(message) for message in session.async_get_messages()
         ]
@@ -197,16 +204,25 @@ class OpenAIConversationEntity(
 
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
-            try:
-                result = await client.chat.completions.create(
-                    model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
-                    messages=messages,
-                    tools=tools or NOT_GIVEN,
-                    max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
-                    top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
-                    temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
-                    user=session.conversation_id,
+            model_args = {
+                "model": model,
+                "messages": messages,
+                "tools": tools or NOT_GIVEN,
+                "max_completion_tokens": options.get(
+                    CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
+                ),
+                "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
+                "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
+                "user": session.conversation_id,
+            }
+
+            if model.startswith("o"):
+                model_args["reasoning_effort"] = options.get(
+                    CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
                 )
+
+            try:
+                result = await client.chat.completions.create(**model_args)
             except openai.OpenAIError as err:
                 LOGGER.error("Error talking to OpenAI: %s", err)
                 raise HomeAssistantError("Error talking to OpenAI") from err
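Two details in the hunk above are easy to miss: the request kwargs are now assembled in a plain dict so that `reasoning_effort` can be attached only for o-series models (other chat models reject the parameter), and `max_tokens` was swapped for `max_completion_tokens`, which both model families accept and which also budgets the hidden reasoning tokens. A condensed sketch of that pattern, with placeholder values rather than the integration's options:

    async def create_completion(client, model: str, messages: list):
        """Build kwargs conditionally, then make one guarded API call (sketch)."""
        model_args = {
            "model": model,
            "messages": messages,
            # Accepted by all chat models; also covers hidden reasoning tokens.
            "max_completion_tokens": 300,
        }
        if model.startswith("o"):
            # Only o-series reasoning models understand this parameter.
            model_args["reasoning_effort"] = "low"
        # In the real code only the network call sits inside try/except;
        # assembling the dict cannot raise an API error.
        return await client.chat.completions.create(**model_args)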
--- a/homeassistant/components/openai_conversation/strings.json
+++ b/homeassistant/components/openai_conversation/strings.json
@@ -23,12 +23,26 @@
           "temperature": "Temperature",
           "top_p": "Top P",
           "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
-          "recommended": "Recommended model settings"
+          "recommended": "Recommended model settings",
+          "reasoning_effort": "Reasoning effort"
         },
         "data_description": {
-          "prompt": "Instruct how the LLM should respond. This can be a template."
+          "prompt": "Instruct how the LLM should respond. This can be a template.",
+          "reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt (for certain reasoning models)"
         }
       }
+    },
+    "error": {
+      "model_not_supported": "This model is not supported, please select a different model"
+    }
+  },
+  "selector": {
+    "reasoning_effort": {
+      "options": {
+        "low": "Low",
+        "medium": "Medium",
+        "high": "High"
+      }
     }
   },
   "services": {
--- a/tests/components/openai_conversation/test_config_flow.py
+++ b/tests/components/openai_conversation/test_config_flow.py
@@ -12,12 +12,14 @@ from homeassistant.components.openai_conversation.const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
     CONF_PROMPT,
+    CONF_REASONING_EFFORT,
     CONF_RECOMMENDED,
     CONF_TEMPERATURE,
     CONF_TOP_P,
     DOMAIN,
     RECOMMENDED_CHAT_MODEL,
     RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_REASONING_EFFORT,
     RECOMMENDED_TOP_P,
 )
 from homeassistant.const import CONF_LLM_HASS_API
@@ -88,6 +90,27 @@ async def test_options(
     assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL
 
 
+async def test_options_unsupported_model(
+    hass: HomeAssistant, mock_config_entry, mock_init_component
+) -> None:
+    """Test the options form giving error about models not supported."""
+    options_flow = await hass.config_entries.options.async_init(
+        mock_config_entry.entry_id
+    )
+    result = await hass.config_entries.options.async_configure(
+        options_flow["flow_id"],
+        {
+            CONF_RECOMMENDED: False,
+            CONF_PROMPT: "Speak like a pirate",
+            CONF_CHAT_MODEL: "o1-mini",
+            CONF_LLM_HASS_API: "assist",
+        },
+    )
+    await hass.async_block_till_done()
+    assert result["type"] is FlowResultType.FORM
+    assert result["errors"] == {"chat_model": "model_not_supported"}
+
+
 @pytest.mark.parametrize(
     ("side_effect", "error"),
     [
@@ -148,6 +171,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
                 CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
                 CONF_TOP_P: RECOMMENDED_TOP_P,
                 CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+                CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
             },
         ),
         (
@@ -158,6 +182,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
                 CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
                 CONF_TOP_P: RECOMMENDED_TOP_P,
                 CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+                CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
             },
             {
                 CONF_RECOMMENDED: True,