Add support for OpenAI reasoning models (#137139)

* Add support for OpenAI reasoning models

* Apply suggestions from code review

* Remove o1-mini* and o1-preview* model support

* List unsupported models

* Re-enable audio models (they also support text)
Denis Shulyaka 2025-02-03 00:55:16 +03:00 committed by GitHub
parent a6781107df
commit 0f36759a38
5 changed files with 111 additions and 21 deletions

homeassistant/components/openai_conversation/config_flow.py

@@ -24,6 +24,7 @@ from homeassistant.helpers.selector import (
     SelectOptionDict,
     SelectSelector,
     SelectSelectorConfig,
+    SelectSelectorMode,
     TemplateSelector,
 )
 from homeassistant.helpers.typing import VolDictType
@@ -32,14 +33,17 @@ from .const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
     CONF_PROMPT,
+    CONF_REASONING_EFFORT,
     CONF_RECOMMENDED,
     CONF_TEMPERATURE,
     CONF_TOP_P,
     DOMAIN,
     RECOMMENDED_CHAT_MODEL,
     RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_REASONING_EFFORT,
     RECOMMENDED_TEMPERATURE,
     RECOMMENDED_TOP_P,
+    UNSUPPORTED_MODELS,
 )

 _LOGGER = logging.getLogger(__name__)
@@ -124,26 +128,32 @@ class OpenAIOptionsFlow(OptionsFlow):
     ) -> ConfigFlowResult:
         """Manage the options."""
         options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
+        errors: dict[str, str] = {}

         if user_input is not None:
             if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
                 if user_input[CONF_LLM_HASS_API] == "none":
                     user_input.pop(CONF_LLM_HASS_API)
-                return self.async_create_entry(title="", data=user_input)

-            # Re-render the options again, now with the recommended options shown/hidden
-            self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
+                if user_input.get(CONF_CHAT_MODEL) in UNSUPPORTED_MODELS:
+                    errors[CONF_CHAT_MODEL] = "model_not_supported"
+                else:
+                    return self.async_create_entry(title="", data=user_input)
+            else:
+                # Re-render the options again, now with the recommended options shown/hidden
+                self.last_rendered_recommended = user_input[CONF_RECOMMENDED]

-            options = {
-                CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
-                CONF_PROMPT: user_input[CONF_PROMPT],
-                CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
-            }
+                options = {
+                    CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
+                    CONF_PROMPT: user_input[CONF_PROMPT],
+                    CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
+                }

         schema = openai_config_option_schema(self.hass, options)
         return self.async_show_form(
             step_id="init",
             data_schema=vol.Schema(schema),
+            errors=errors,
         )
@@ -210,6 +220,17 @@ def openai_config_option_schema(
                 description={"suggested_value": options.get(CONF_TEMPERATURE)},
                 default=RECOMMENDED_TEMPERATURE,
             ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
+            vol.Optional(
+                CONF_REASONING_EFFORT,
+                description={"suggested_value": options.get(CONF_REASONING_EFFORT)},
+                default=RECOMMENDED_REASONING_EFFORT,
+            ): SelectSelector(
+                SelectSelectorConfig(
+                    options=["low", "medium", "high"],
+                    translation_key="reasoning_effort",
+                    mode=SelectSelectorMode.DROPDOWN,
+                )
+            ),
         }
     )

     return schema
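
When a listed model is submitted, the flow above re-renders the form with an error instead of creating the entry. Here is a minimal standalone sketch of that behavior; the helper name and the two-model list are illustrative only, since the real check runs inside OpenAIOptionsFlow.async_step_init against the full UNSUPPORTED_MODELS list in const.py.

# Sketch: mirrors the membership check added above (hypothetical helper).
UNSUPPORTED_MODELS = ["o1-mini", "o1-preview"]  # excerpt of the full list

def validate_chat_model(user_input: dict[str, str]) -> dict[str, str]:
    """Return form errors; an empty dict means the entry can be saved."""
    errors: dict[str, str] = {}
    if user_input.get("chat_model") in UNSUPPORTED_MODELS:
        # The key is resolved to a translated message via strings.json.
        errors["chat_model"] = "model_not_supported"
    return errors

assert validate_chat_model({"chat_model": "o1-mini"}) == {
    "chat_model": "model_not_supported"
}
assert validate_chat_model({"chat_model": "gpt-4o"}) == {}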

homeassistant/components/openai_conversation/const.py

@@ -15,3 +15,17 @@ CONF_TOP_P = "top_p"
 RECOMMENDED_TOP_P = 1.0
 CONF_TEMPERATURE = "temperature"
 RECOMMENDED_TEMPERATURE = 1.0
+CONF_REASONING_EFFORT = "reasoning_effort"
+RECOMMENDED_REASONING_EFFORT = "low"
+
+UNSUPPORTED_MODELS = [
+    "o1-mini",
+    "o1-mini-2024-09-12",
+    "o1-preview",
+    "o1-preview-2024-09-12",
+    "gpt-4o-realtime-preview",
+    "gpt-4o-realtime-preview-2024-12-17",
+    "gpt-4o-realtime-preview-2024-10-01",
+    "gpt-4o-mini-realtime-preview",
+    "gpt-4o-mini-realtime-preview-2024-12-17",
+]
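
The option-flow check is an exact membership test against this list, which is why each dated snapshot is spelled out even though the commit message phrases the removals as o1-mini* and o1-preview*. If wildcard-style matching were wanted instead, a prefix helper along these lines would also cover future snapshots; this is a hypothetical alternative, not what the commit does.

UNSUPPORTED_PREFIXES = (
    "o1-mini",
    "o1-preview",
    "gpt-4o-realtime-preview",
    "gpt-4o-mini-realtime-preview",
)

def is_unsupported(model: str) -> bool:
    """Match a base model and any dated snapshot, e.g. 'o1-mini-2024-09-12'."""
    return model.startswith(UNSUPPORTED_PREFIXES)

assert is_unsupported("o1-mini-2024-09-12")
assert not is_unsupported("o3-mini")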

homeassistant/components/openai_conversation/conversation.py

@@ -31,12 +31,14 @@ from .const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
     CONF_PROMPT,
+    CONF_REASONING_EFFORT,
     CONF_TEMPERATURE,
     CONF_TOP_P,
     DOMAIN,
     LOGGER,
     RECOMMENDED_CHAT_MODEL,
     RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_REASONING_EFFORT,
     RECOMMENDED_TEMPERATURE,
     RECOMMENDED_TOP_P,
 )
@@ -97,12 +99,15 @@ def _chat_message_convert(
     | conversation.NativeContent[ChatCompletionMessageParam],
 ) -> ChatCompletionMessageParam:
     """Convert any native chat message for this agent to the native format."""
-    if message.role == "native":
+    role = message.role
+    if role == "native":
         # mypy doesn't understand that checking role ensures content type
         return message.content  # type: ignore[return-value]
+    if role == "system":
+        role = "developer"
     return cast(
         ChatCompletionMessageParam,
-        {"role": message.role, "content": message.content},
+        {"role": role, "content": message.content},
     )
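
Reasoning models take their instructions through the "developer" role rather than "system", so the conversion maps the role at request time and the stored conversation history can keep using "system". A trivial sketch of the mapping, with the surrounding cast omitted:

def convert_role(role: str) -> str:
    """Map Home Assistant's 'system' role to OpenAI's 'developer' role."""
    return "developer" if role == "system" else role

assert convert_role("system") == "developer"
assert convert_role("user") == "user"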
@@ -189,6 +194,8 @@ class OpenAIConversationEntity(
             for tool in session.llm_api.tools
         ]

+        model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
+
         messages = [
             _chat_message_convert(message) for message in session.async_get_messages()
         ]
@@ -197,16 +204,25 @@ class OpenAIConversationEntity(
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
+            model_args = {
+                "model": model,
+                "messages": messages,
+                "tools": tools or NOT_GIVEN,
+                "max_completion_tokens": options.get(
+                    CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
+                ),
+                "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
+                "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
+                "user": session.conversation_id,
+            }
+
+            if model.startswith("o"):
+                model_args["reasoning_effort"] = options.get(
+                    CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
+                )
+
             try:
-                result = await client.chat.completions.create(
-                    model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
-                    messages=messages,
-                    tools=tools or NOT_GIVEN,
-                    max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
-                    top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
-                    temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
-                    user=session.conversation_id,
-                )
+                result = await client.chat.completions.create(**model_args)
             except openai.OpenAIError as err:
                 LOGGER.error("Error talking to OpenAI: %s", err)
                 raise HomeAssistantError("Error talking to OpenAI") from err
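
Building model_args as a plain dict lets the integration attach reasoning_effort only when the model name starts with "o" (o1, o3-mini, and so on), since non-reasoning models do not accept that parameter. It also swaps max_tokens for max_completion_tokens, which the o-series requires and which budgets the hidden reasoning tokens as well. Outside Home Assistant, the equivalent request looks roughly like this; a sketch assuming a recent openai Python SDK with reasoning-model support and OPENAI_API_KEY set in the environment.

import asyncio

from openai import AsyncOpenAI

async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    result = await client.chat.completions.create(
        model="o1",
        messages=[
            # Instructions go in via the "developer" role for o-series models.
            {"role": "developer", "content": "You are a voice assistant."},
            {"role": "user", "content": "Turn on the kitchen lights."},
        ],
        # o-series models reject max_tokens; max_completion_tokens also
        # covers the hidden reasoning tokens.
        max_completion_tokens=500,
        reasoning_effort="low",  # "low" | "medium" | "high"
    )
    print(result.choices[0].message.content)

asyncio.run(main())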

homeassistant/components/openai_conversation/strings.json

@@ -23,12 +23,26 @@
           "temperature": "Temperature",
           "top_p": "Top P",
           "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
-          "recommended": "Recommended model settings"
+          "recommended": "Recommended model settings",
+          "reasoning_effort": "Reasoning effort"
         },
         "data_description": {
-          "prompt": "Instruct how the LLM should respond. This can be a template."
+          "prompt": "Instruct how the LLM should respond. This can be a template.",
+          "reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt (for certain reasoning models)"
         }
       }
-    }
+    },
+    "error": {
+      "model_not_supported": "This model is not supported, please select a different model"
+    }
+  },
+  "selector": {
+    "reasoning_effort": {
+      "options": {
+        "low": "Low",
+        "medium": "Medium",
+        "high": "High"
+      }
+    }
   },
   "services": {

tests/components/openai_conversation/test_config_flow.py

@@ -12,12 +12,14 @@ from homeassistant.components.openai_conversation.const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
     CONF_PROMPT,
+    CONF_REASONING_EFFORT,
     CONF_RECOMMENDED,
     CONF_TEMPERATURE,
     CONF_TOP_P,
     DOMAIN,
     RECOMMENDED_CHAT_MODEL,
     RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_REASONING_EFFORT,
     RECOMMENDED_TOP_P,
 )
 from homeassistant.const import CONF_LLM_HASS_API
@@ -88,6 +90,27 @@ async def test_options(
     assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL


+async def test_options_unsupported_model(
+    hass: HomeAssistant, mock_config_entry, mock_init_component
+) -> None:
+    """Test the options form giving error about models not supported."""
+    options_flow = await hass.config_entries.options.async_init(
+        mock_config_entry.entry_id
+    )
+    result = await hass.config_entries.options.async_configure(
+        options_flow["flow_id"],
+        {
+            CONF_RECOMMENDED: False,
+            CONF_PROMPT: "Speak like a pirate",
+            CONF_CHAT_MODEL: "o1-mini",
+            CONF_LLM_HASS_API: "assist",
+        },
+    )
+    await hass.async_block_till_done()
+    assert result["type"] is FlowResultType.FORM
+    assert result["errors"] == {"chat_model": "model_not_supported"}
+
+
 @pytest.mark.parametrize(
     ("side_effect", "error"),
     [
@@ -148,6 +171,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
             CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
             CONF_TOP_P: RECOMMENDED_TOP_P,
             CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+            CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
         },
     ),
     (
@@ -158,6 +182,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
             CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
             CONF_TOP_P: RECOMMENDED_TOP_P,
             CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+            CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
         },
         {
             CONF_RECOMMENDED: True,