Add Code Interpreter tool for OpenAI Conversation (#148383)

This commit is contained in:
Denis Shulyaka 2025-07-16 19:52:53 +07:00 committed by GitHub
parent 0d79f7db51
commit 3e465da892
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 206 additions and 24 deletions

View File

@ -42,6 +42,7 @@ from homeassistant.helpers.typing import VolDictType
from .const import (
CONF_CHAT_MODEL,
CONF_CODE_INTERPRETER,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_REASONING_EFFORT,
@ -60,6 +61,7 @@ from .const import (
DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_CODE_INTERPRETER,
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
@ -312,7 +314,12 @@ class OpenAISubentryFlowHandler(ConfigSubentryFlow):
options = self.options
errors: dict[str, str] = {}
step_schema: VolDictType = {}
step_schema: VolDictType = {
vol.Optional(
CONF_CODE_INTERPRETER,
default=RECOMMENDED_CODE_INTERPRETER,
): bool,
}
model = options[CONF_CHAT_MODEL]
@ -375,18 +382,6 @@ class OpenAISubentryFlowHandler(ConfigSubentryFlow):
)
}
if not step_schema:
if self._is_new:
return self.async_create_entry(
title=options.pop(CONF_NAME),
data=options,
)
return self.async_update_and_abort(
self._get_entry(),
self._get_reconfigure_subentry(),
data=options,
)
if user_input is not None:
if user_input.get(CONF_WEB_SEARCH):
if user_input.get(CONF_WEB_SEARCH_USER_LOCATION):

View File

@ -13,6 +13,7 @@ DEFAULT_AI_TASK_NAME = "OpenAI AI Task"
DEFAULT_NAME = "OpenAI Conversation"
CONF_CHAT_MODEL = "chat_model"
CONF_CODE_INTERPRETER = "code_interpreter"
CONF_FILENAMES = "filenames"
CONF_MAX_TOKENS = "max_tokens"
CONF_PROMPT = "prompt"
@ -27,6 +28,7 @@ CONF_WEB_SEARCH_CITY = "city"
CONF_WEB_SEARCH_REGION = "region"
CONF_WEB_SEARCH_COUNTRY = "country"
CONF_WEB_SEARCH_TIMEZONE = "timezone"
RECOMMENDED_CODE_INTERPRETER = False
RECOMMENDED_CHAT_MODEL = "gpt-4o-mini"
RECOMMENDED_MAX_TOKENS = 3000
RECOMMENDED_REASONING_EFFORT = "low"

View File

@ -38,6 +38,10 @@ from openai.types.responses import (
WebSearchToolParam,
)
from openai.types.responses.response_input_param import FunctionCallOutput
from openai.types.responses.tool_param import (
CodeInterpreter,
CodeInterpreterContainerCodeInterpreterToolAuto,
)
from openai.types.responses.web_search_tool_param import UserLocation
import voluptuous as vol
from voluptuous_openapi import convert
@ -52,6 +56,7 @@ from homeassistant.util import slugify
from .const import (
CONF_CHAT_MODEL,
CONF_CODE_INTERPRETER,
CONF_MAX_TOKENS,
CONF_REASONING_EFFORT,
CONF_TEMPERATURE,
@ -292,7 +297,7 @@ class OpenAIBaseLLMEntity(Entity):
"""Generate an answer for the chat log."""
options = self.subentry.data
tools: list[ToolParam] | None = None
tools: list[ToolParam] = []
if chat_log.llm_api:
tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer)
@ -314,10 +319,18 @@ class OpenAIBaseLLMEntity(Entity):
country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
)
if tools is None:
tools = []
tools.append(web_search)
if options.get(CONF_CODE_INTERPRETER):
tools.append(
CodeInterpreter(
type="code_interpreter",
container=CodeInterpreterContainerCodeInterpreterToolAuto(
type="auto"
),
)
)
model_args = {
"model": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
"input": [],

View File

@ -48,12 +48,14 @@
"model": {
"title": "Model-specific options",
"data": {
"code_interpreter": "Enable code interpreter tool",
"reasoning_effort": "Reasoning effort",
"web_search": "Enable web search",
"search_context_size": "Search context size",
"user_location": "Include home location"
},
"data_description": {
"code_interpreter": "This tool, also known as the python tool to the model, allows it to run code to answer questions",
"reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt",
"web_search": "Allow the model to search the web for the latest information before generating a response",
"search_context_size": "High level guidance for the amount of context window space to use for the search",

View File

@ -1,6 +1,12 @@
"""Tests for the OpenAI Conversation integration."""
from openai.types.responses import (
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
ResponseCodeInterpreterCallCompletedEvent,
ResponseCodeInterpreterCallInProgressEvent,
ResponseCodeInterpreterCallInterpretingEvent,
ResponseCodeInterpreterToolCall,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseFunctionCallArgumentsDeltaEvent,
@ -239,3 +245,86 @@ def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEve
type="response.output_item.done",
),
]
def create_code_interpreter_item(
    id: str, code: str | list[str], output_index: int
) -> list[ResponseStreamEvent]:
    """Create the stream events for one code interpreter tool call.

    Mirrors the event sequence emitted for a ``code_interpreter_call``
    output item: item added, call in_progress, one code delta per chunk,
    code done, interpreting, completed, and finally item done.

    Args:
        id: The tool-call item id (e.g. ``"ci_A"``).
        code: The interpreter source, either as one string or pre-split
            into the delta chunks to stream.
        output_index: Index of this item in the response output.

    Returns:
        The ordered list of stream events for the call.

    All ``sequence_number`` values are left at 0.
    # NOTE(review): presumably the mock stream fixture renumbers them — confirm.
    """
    # Normalize a single string into a one-chunk delta list.
    if isinstance(code, str):
        code = [code]

    container_id = "cntr_A"

    # Opening events: the output item appears with empty code, then the
    # call is reported as in progress.
    events = [
        ResponseOutputItemAddedEvent(
            item=ResponseCodeInterpreterToolCall(
                id=id,
                code="",
                container_id=container_id,
                outputs=None,
                type="code_interpreter_call",
                status="in_progress",
            ),
            output_index=output_index,
            sequence_number=0,
            type="response.output_item.added",
        ),
        ResponseCodeInterpreterCallInProgressEvent(
            item_id=id,
            output_index=output_index,
            sequence_number=0,
            type="response.code_interpreter_call.in_progress",
        ),
    ]

    # One delta event per code chunk, in order.
    events.extend(
        ResponseCodeInterpreterCallCodeDeltaEvent(
            delta=delta,
            item_id=id,
            output_index=output_index,
            sequence_number=0,
            type="response.code_interpreter_call_code.delta",
        )
        for delta in code
    )

    # The "done" events carry the full, reassembled source.
    code = "".join(code)

    # Closing events: code done, interpreting, completed, then the final
    # output item with the complete code and "completed" status.
    events.extend(
        [
            ResponseCodeInterpreterCallCodeDoneEvent(
                item_id=id,
                output_index=output_index,
                code=code,
                sequence_number=0,
                type="response.code_interpreter_call_code.done",
            ),
            ResponseCodeInterpreterCallInterpretingEvent(
                item_id=id,
                output_index=output_index,
                sequence_number=0,
                type="response.code_interpreter_call.interpreting",
            ),
            ResponseCodeInterpreterCallCompletedEvent(
                item_id=id,
                output_index=output_index,
                sequence_number=0,
                type="response.code_interpreter_call.completed",
            ),
            ResponseOutputItemDoneEvent(
                item=ResponseCodeInterpreterToolCall(
                    id=id,
                    code=code,
                    container_id=container_id,
                    outputs=None,
                    status="completed",
                    type="code_interpreter_call",
                ),
                output_index=output_index,
                sequence_number=0,
                type="response.output_item.done",
            ),
        ]
    )

    return events

View File

@ -156,9 +156,10 @@ def mock_create_stream() -> Generator[AsyncMock]:
)
yield ResponseInProgressEvent(
response=response,
sequence_number=0,
sequence_number=1,
type="response.in_progress",
)
sequence_number = 2
response.status = "completed"
for value in events:
@ -173,6 +174,8 @@ def mock_create_stream() -> Generator[AsyncMock]:
response.error = value
break
value.sequence_number = sequence_number
sequence_number += 1
yield value
if isinstance(value, ResponseErrorEvent):
@ -181,19 +184,19 @@ def mock_create_stream() -> Generator[AsyncMock]:
if response.status == "incomplete":
yield ResponseIncompleteEvent(
response=response,
sequence_number=0,
sequence_number=sequence_number,
type="response.incomplete",
)
elif response.status == "failed":
yield ResponseFailedEvent(
response=response,
sequence_number=0,
sequence_number=sequence_number,
type="response.failed",
)
else:
yield ResponseCompletedEvent(
response=response,
sequence_number=0,
sequence_number=sequence_number,
type="response.completed",
)

View File

@ -13,6 +13,7 @@ from homeassistant.components.openai_conversation.config_flow import (
)
from homeassistant.components.openai_conversation.const import (
CONF_CHAT_MODEL,
CONF_CODE_INTERPRETER,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_REASONING_EFFORT,
@ -311,6 +312,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
},
{
CONF_REASONING_EFFORT: "high",
CONF_CODE_INTERPRETER: True,
},
),
{
@ -321,6 +323,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_MAX_TOKENS: 10000,
CONF_REASONING_EFFORT: "high",
CONF_CODE_INTERPRETER: True,
},
),
( # options for web search without user location
@ -343,6 +346,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_WEB_SEARCH: True,
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
CONF_WEB_SEARCH_USER_LOCATION: False,
CONF_CODE_INTERPRETER: False,
},
),
{
@ -355,6 +359,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_WEB_SEARCH: True,
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
CONF_WEB_SEARCH_USER_LOCATION: False,
CONF_CODE_INTERPRETER: False,
},
),
# Test that current options are showed as suggested values
@ -373,6 +378,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_WEB_SEARCH_REGION: "California",
CONF_WEB_SEARCH_COUNTRY: "US",
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
CONF_CODE_INTERPRETER: True,
},
(
{
@ -389,6 +395,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_WEB_SEARCH: True,
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
CONF_WEB_SEARCH_USER_LOCATION: False,
CONF_CODE_INTERPRETER: True,
},
),
{
@ -401,6 +408,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_WEB_SEARCH: True,
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
CONF_WEB_SEARCH_USER_LOCATION: False,
CONF_CODE_INTERPRETER: True,
},
),
( # Case 2: reasoning model
@ -424,7 +432,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_TOP_P: 0.9,
CONF_MAX_TOKENS: 1000,
},
{CONF_REASONING_EFFORT: "high"},
{CONF_REASONING_EFFORT: "high", CONF_CODE_INTERPRETER: False},
),
{
CONF_RECOMMENDED: False,
@ -434,6 +442,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_TOP_P: 0.9,
CONF_MAX_TOKENS: 1000,
CONF_REASONING_EFFORT: "high",
CONF_CODE_INTERPRETER: False,
},
),
# Test that old options are removed after reconfiguration
@ -445,6 +454,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_CHAT_MODEL: "gpt-4o",
CONF_TOP_P: 0.9,
CONF_MAX_TOKENS: 1000,
CONF_CODE_INTERPRETER: True,
CONF_WEB_SEARCH: True,
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
CONF_WEB_SEARCH_USER_LOCATION: True,
@ -476,6 +486,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_TOP_P: 0.9,
CONF_MAX_TOKENS: 1000,
CONF_REASONING_EFFORT: "high",
CONF_CODE_INTERPRETER: True,
},
(
{
@ -504,6 +515,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_WEB_SEARCH_REGION: "California",
CONF_WEB_SEARCH_COUNTRY: "US",
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
CONF_CODE_INTERPRETER: True,
},
(
{
@ -518,6 +530,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
},
{
CONF_REASONING_EFFORT: "low",
CONF_CODE_INTERPRETER: True,
},
),
{
@ -528,6 +541,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_TOP_P: 0.9,
CONF_MAX_TOKENS: 1000,
CONF_REASONING_EFFORT: "low",
CONF_CODE_INTERPRETER: True,
},
),
( # Case 4: reasoning to web search
@ -540,6 +554,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_TOP_P: 0.9,
CONF_MAX_TOKENS: 1000,
CONF_REASONING_EFFORT: "low",
CONF_CODE_INTERPRETER: True,
},
(
{
@ -556,6 +571,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_WEB_SEARCH: True,
CONF_WEB_SEARCH_CONTEXT_SIZE: "high",
CONF_WEB_SEARCH_USER_LOCATION: False,
CONF_CODE_INTERPRETER: False,
},
),
{
@ -568,6 +584,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_WEB_SEARCH: True,
CONF_WEB_SEARCH_CONTEXT_SIZE: "high",
CONF_WEB_SEARCH_USER_LOCATION: False,
CONF_CODE_INTERPRETER: False,
},
),
],
@ -718,6 +735,7 @@ async def test_subentry_web_search_user_location(
CONF_WEB_SEARCH_REGION: "California",
CONF_WEB_SEARCH_COUNTRY: "US",
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
CONF_CODE_INTERPRETER: False,
}
@ -817,12 +835,24 @@ async def test_creating_ai_task_subentry_advanced(
},
)
assert result3.get("type") is FlowResultType.CREATE_ENTRY
assert result3.get("title") == "Advanced AI Task"
assert result3.get("data") == {
assert result3.get("type") is FlowResultType.FORM
assert result3.get("step_id") == "model"
# Configure model settings
result4 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{
CONF_CODE_INTERPRETER: False,
},
)
assert result4.get("type") is FlowResultType.CREATE_ENTRY
assert result4.get("title") == "Advanced AI Task"
assert result4.get("data") == {
CONF_RECOMMENDED: False,
CONF_CHAT_MODEL: "gpt-4o",
CONF_MAX_TOKENS: 200,
CONF_TEMPERATURE: 0.5,
CONF_TOP_P: 0.9,
CONF_CODE_INTERPRETER: False,
}

View File

@ -16,6 +16,7 @@ from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.components.openai_conversation.const import (
CONF_CODE_INTERPRETER,
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_CITY,
CONF_WEB_SEARCH_CONTEXT_SIZE,
@ -30,6 +31,7 @@ from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
from . import (
create_code_interpreter_item,
create_function_tool_call_item,
create_message_item,
create_reasoning_item,
@ -485,3 +487,49 @@ async def test_web_search(
]
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.speech["plain"]["speech"] == message, result.response.speech
async def test_code_interpreter(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    mock_create_stream,
    mock_chat_log: MockChatLog,  # noqa: F811
) -> None:
    """Test that enabling code_interpreter passes the tool to the API.

    Enables the option on the conversation subentry, streams a mocked
    code interpreter call followed by a message, and verifies the tool
    spec sent to the API and the final spoken response.
    """
    # Turn on the code interpreter option for the existing subentry.
    subentry = next(iter(mock_config_entry.subentries.values()))
    hass.config_entries.async_update_subentry(
        mock_config_entry,
        subentry,
        data={
            **subentry.data,
            CONF_CODE_INTERPRETER: True,
        },
    )
    # Reload so the entity picks up the changed subentry options.
    await hass.config_entries.async_reload(mock_config_entry.entry_id)

    message = "Ive calculated it with Python: the square root of 55555 is approximately 235.70108188126758."
    # Mocked stream: a code interpreter call (code split into delta
    # chunks) followed by the assistant's text message.
    mock_create_stream.return_value = [
        (
            *create_code_interpreter_item(
                id="ci_A",
                code=["import", " math", "\n", "math", ".sqrt", "(", "555", "55", ")"],
                output_index=0,
            ),
            *create_message_item(id="msg_A", text=message, output_index=1),
        )
    ]

    result = await conversation.async_converse(
        hass,
        "Please use the python tool to calculate square root of 55555",
        mock_chat_log.conversation_id,
        Context(),
        agent_id="conversation.openai_conversation",
    )

    # The API call must include the code_interpreter tool with an
    # auto-managed container.
    assert mock_create_stream.mock_calls[0][2]["tools"] == [
        {"type": "code_interpreter", "container": {"type": "auto"}}
    ]
    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    assert result.response.speech["plain"]["speech"] == message, result.response.speech