From 3e465da89208633c8b238ad9a89cdac2b763797e Mon Sep 17 00:00:00 2001 From: Denis Shulyaka Date: Wed, 16 Jul 2025 19:52:53 +0700 Subject: [PATCH] Add Code Interpreter tool for OpenAI Conversation (#148383) --- .../openai_conversation/config_flow.py | 21 ++--- .../components/openai_conversation/const.py | 2 + .../components/openai_conversation/entity.py | 19 +++- .../openai_conversation/strings.json | 2 + .../openai_conversation/__init__.py | 89 +++++++++++++++++++ .../openai_conversation/conftest.py | 11 ++- .../openai_conversation/test_config_flow.py | 38 +++++++- .../openai_conversation/test_conversation.py | 48 ++++++++++ 8 files changed, 206 insertions(+), 24 deletions(-) diff --git a/homeassistant/components/openai_conversation/config_flow.py b/homeassistant/components/openai_conversation/config_flow.py index ce6872c7c20..aa1c967ca8f 100644 --- a/homeassistant/components/openai_conversation/config_flow.py +++ b/homeassistant/components/openai_conversation/config_flow.py @@ -42,6 +42,7 @@ from homeassistant.helpers.typing import VolDictType from .const import ( CONF_CHAT_MODEL, + CONF_CODE_INTERPRETER, CONF_MAX_TOKENS, CONF_PROMPT, CONF_REASONING_EFFORT, @@ -60,6 +61,7 @@ from .const import ( DOMAIN, RECOMMENDED_AI_TASK_OPTIONS, RECOMMENDED_CHAT_MODEL, + RECOMMENDED_CODE_INTERPRETER, RECOMMENDED_CONVERSATION_OPTIONS, RECOMMENDED_MAX_TOKENS, RECOMMENDED_REASONING_EFFORT, @@ -312,7 +314,12 @@ class OpenAISubentryFlowHandler(ConfigSubentryFlow): options = self.options errors: dict[str, str] = {} - step_schema: VolDictType = {} + step_schema: VolDictType = { + vol.Optional( + CONF_CODE_INTERPRETER, + default=RECOMMENDED_CODE_INTERPRETER, + ): bool, + } model = options[CONF_CHAT_MODEL] @@ -375,18 +382,6 @@ class OpenAISubentryFlowHandler(ConfigSubentryFlow): ) } - if not step_schema: - if self._is_new: - return self.async_create_entry( - title=options.pop(CONF_NAME), - data=options, - ) - return self.async_update_and_abort( - self._get_entry(), - self._get_reconfigure_subentry(), - data=options, - ) - if user_input is not None: if user_input.get(CONF_WEB_SEARCH): if user_input.get(CONF_WEB_SEARCH_USER_LOCATION): diff --git a/homeassistant/components/openai_conversation/const.py b/homeassistant/components/openai_conversation/const.py index a15f71118c0..cacef6fcff9 100644 --- a/homeassistant/components/openai_conversation/const.py +++ b/homeassistant/components/openai_conversation/const.py @@ -13,6 +13,7 @@ DEFAULT_AI_TASK_NAME = "OpenAI AI Task" DEFAULT_NAME = "OpenAI Conversation" CONF_CHAT_MODEL = "chat_model" +CONF_CODE_INTERPRETER = "code_interpreter" CONF_FILENAMES = "filenames" CONF_MAX_TOKENS = "max_tokens" CONF_PROMPT = "prompt" @@ -27,6 +28,7 @@ CONF_WEB_SEARCH_CITY = "city" CONF_WEB_SEARCH_REGION = "region" CONF_WEB_SEARCH_COUNTRY = "country" CONF_WEB_SEARCH_TIMEZONE = "timezone" +RECOMMENDED_CODE_INTERPRETER = False RECOMMENDED_CHAT_MODEL = "gpt-4o-mini" RECOMMENDED_MAX_TOKENS = 3000 RECOMMENDED_REASONING_EFFORT = "low" diff --git a/homeassistant/components/openai_conversation/entity.py b/homeassistant/components/openai_conversation/entity.py index 7679bef83f1..93713c78d9c 100644 --- a/homeassistant/components/openai_conversation/entity.py +++ b/homeassistant/components/openai_conversation/entity.py @@ -38,6 +38,10 @@ from openai.types.responses import ( WebSearchToolParam, ) from openai.types.responses.response_input_param import FunctionCallOutput +from openai.types.responses.tool_param import ( + CodeInterpreter, + CodeInterpreterContainerCodeInterpreterToolAuto, +) from openai.types.responses.web_search_tool_param import UserLocation import voluptuous as vol from voluptuous_openapi import convert @@ -52,6 +56,7 @@ from homeassistant.util import slugify from .const import ( CONF_CHAT_MODEL, + CONF_CODE_INTERPRETER, CONF_MAX_TOKENS, CONF_REASONING_EFFORT, CONF_TEMPERATURE, @@ -292,7 +297,7 @@ class OpenAIBaseLLMEntity(Entity): """Generate an answer for the chat log.""" options = self.subentry.data - tools: list[ToolParam] | None = None + tools: list[ToolParam] = [] if chat_log.llm_api: tools = [ _format_tool(tool, chat_log.llm_api.custom_serializer) @@ -314,10 +319,18 @@ class OpenAIBaseLLMEntity(Entity): country=options.get(CONF_WEB_SEARCH_COUNTRY, ""), timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""), ) - if tools is None: - tools = [] tools.append(web_search) + if options.get(CONF_CODE_INTERPRETER): + tools.append( + CodeInterpreter( + type="code_interpreter", + container=CodeInterpreterContainerCodeInterpreterToolAuto( + type="auto" + ), + ) + ) + model_args = { "model": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), "input": [], diff --git a/homeassistant/components/openai_conversation/strings.json b/homeassistant/components/openai_conversation/strings.json index 5011fc9cf99..fef955b4fa9 100644 --- a/homeassistant/components/openai_conversation/strings.json +++ b/homeassistant/components/openai_conversation/strings.json @@ -48,12 +48,14 @@ "model": { "title": "Model-specific options", "data": { + "code_interpreter": "Enable code interpreter tool", "reasoning_effort": "Reasoning effort", "web_search": "Enable web search", "search_context_size": "Search context size", "user_location": "Include home location" }, "data_description": { + "code_interpreter": "This tool, also known as the python tool to the model, allows it to run code to answer questions", "reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt", "web_search": "Allow the model to search the web for the latest information before generating a response", "search_context_size": "High level guidance for the amount of context window space to use for the search", diff --git a/tests/components/openai_conversation/__init__.py b/tests/components/openai_conversation/__init__.py index 11dc978250a..c10c23df237 100644 --- a/tests/components/openai_conversation/__init__.py +++ b/tests/components/openai_conversation/__init__.py @@ -1,6 +1,12 @@ """Tests for the OpenAI Conversation integration.""" from openai.types.responses import ( + ResponseCodeInterpreterCallCodeDeltaEvent, + ResponseCodeInterpreterCallCodeDoneEvent, + ResponseCodeInterpreterCallCompletedEvent, + ResponseCodeInterpreterCallInProgressEvent, + ResponseCodeInterpreterCallInterpretingEvent, + ResponseCodeInterpreterToolCall, ResponseContentPartAddedEvent, ResponseContentPartDoneEvent, ResponseFunctionCallArgumentsDeltaEvent, @@ -239,3 +245,86 @@ def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEve type="response.output_item.done", ), ] + + +def create_code_interpreter_item( + id: str, code: str | list[str], output_index: int +) -> list[ResponseStreamEvent]: + """Create a message item.""" + if isinstance(code, str): + code = [code] + + container_id = "cntr_A" + events = [ + ResponseOutputItemAddedEvent( + item=ResponseCodeInterpreterToolCall( + id=id, + code="", + container_id=container_id, + outputs=None, + type="code_interpreter_call", + status="in_progress", + ), + output_index=output_index, + sequence_number=0, + type="response.output_item.added", + ), + ResponseCodeInterpreterCallInProgressEvent( + item_id=id, + output_index=output_index, + sequence_number=0, + type="response.code_interpreter_call.in_progress", + ), + ] + + events.extend( + ResponseCodeInterpreterCallCodeDeltaEvent( + delta=delta, + item_id=id, + output_index=output_index, + sequence_number=0, + type="response.code_interpreter_call_code.delta", + ) + for delta in code + ) + + code = "".join(code) + + events.extend( + [ + ResponseCodeInterpreterCallCodeDoneEvent( + item_id=id, + output_index=output_index, + code=code, + sequence_number=0, + type="response.code_interpreter_call_code.done", + ), + ResponseCodeInterpreterCallInterpretingEvent( + item_id=id, + output_index=output_index, + sequence_number=0, + type="response.code_interpreter_call.interpreting", + ), + ResponseCodeInterpreterCallCompletedEvent( + item_id=id, + output_index=output_index, + sequence_number=0, + type="response.code_interpreter_call.completed", + ), + ResponseOutputItemDoneEvent( + item=ResponseCodeInterpreterToolCall( + id=id, + code=code, + container_id=container_id, + outputs=None, + status="completed", + type="code_interpreter_call", + ), + output_index=output_index, + sequence_number=0, + type="response.output_item.done", + ), + ] + ) + + return events diff --git a/tests/components/openai_conversation/conftest.py b/tests/components/openai_conversation/conftest.py index 84c907a7c2e..b58e6c31f38 100644 --- a/tests/components/openai_conversation/conftest.py +++ b/tests/components/openai_conversation/conftest.py @@ -156,9 +156,10 @@ def mock_create_stream() -> Generator[AsyncMock]: ) yield ResponseInProgressEvent( response=response, - sequence_number=0, + sequence_number=1, type="response.in_progress", ) + sequence_number = 2 response.status = "completed" for value in events: @@ -173,6 +174,8 @@ def mock_create_stream() -> Generator[AsyncMock]: response.error = value break + value.sequence_number = sequence_number + sequence_number += 1 yield value if isinstance(value, ResponseErrorEvent): @@ -181,19 +184,19 @@ def mock_create_stream() -> Generator[AsyncMock]: if response.status == "incomplete": yield ResponseIncompleteEvent( response=response, - sequence_number=0, + sequence_number=sequence_number, type="response.incomplete", ) elif response.status == "failed": yield ResponseFailedEvent( response=response, - sequence_number=0, + sequence_number=sequence_number, type="response.failed", ) else: yield ResponseCompletedEvent( response=response, - sequence_number=0, + sequence_number=sequence_number, type="response.completed", ) diff --git a/tests/components/openai_conversation/test_config_flow.py b/tests/components/openai_conversation/test_config_flow.py index 0ccbc39160a..6d8fb143f88 100644 --- a/tests/components/openai_conversation/test_config_flow.py +++ b/tests/components/openai_conversation/test_config_flow.py @@ -13,6 +13,7 @@ from homeassistant.components.openai_conversation.config_flow import ( ) from homeassistant.components.openai_conversation.const import ( CONF_CHAT_MODEL, + CONF_CODE_INTERPRETER, CONF_MAX_TOKENS, CONF_PROMPT, CONF_REASONING_EFFORT, @@ -311,6 +312,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non }, { CONF_REASONING_EFFORT: "high", + CONF_CODE_INTERPRETER: True, }, ), { @@ -321,6 +323,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_TOP_P: RECOMMENDED_TOP_P, CONF_MAX_TOKENS: 10000, CONF_REASONING_EFFORT: "high", + CONF_CODE_INTERPRETER: True, }, ), ( # options for web search without user location @@ -343,6 +346,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_WEB_SEARCH: True, CONF_WEB_SEARCH_CONTEXT_SIZE: "low", CONF_WEB_SEARCH_USER_LOCATION: False, + CONF_CODE_INTERPRETER: False, }, ), { @@ -355,6 +359,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_WEB_SEARCH: True, CONF_WEB_SEARCH_CONTEXT_SIZE: "low", CONF_WEB_SEARCH_USER_LOCATION: False, + CONF_CODE_INTERPRETER: False, }, ), # Test that current options are showed as suggested values @@ -373,6 +378,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_WEB_SEARCH_REGION: "California", CONF_WEB_SEARCH_COUNTRY: "US", CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles", + CONF_CODE_INTERPRETER: True, }, ( { @@ -389,6 +395,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_WEB_SEARCH: True, CONF_WEB_SEARCH_CONTEXT_SIZE: "low", CONF_WEB_SEARCH_USER_LOCATION: False, + CONF_CODE_INTERPRETER: True, }, ), { @@ -401,6 +408,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_WEB_SEARCH: True, CONF_WEB_SEARCH_CONTEXT_SIZE: "low", CONF_WEB_SEARCH_USER_LOCATION: False, + CONF_CODE_INTERPRETER: True, }, ), ( # Case 2: reasoning model @@ -424,7 +432,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_TOP_P: 0.9, CONF_MAX_TOKENS: 1000, }, - {CONF_REASONING_EFFORT: "high"}, + {CONF_REASONING_EFFORT: "high", CONF_CODE_INTERPRETER: False}, ), { CONF_RECOMMENDED: False, @@ -434,6 +442,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_TOP_P: 0.9, CONF_MAX_TOKENS: 1000, CONF_REASONING_EFFORT: "high", + CONF_CODE_INTERPRETER: False, }, ), # Test that old options are removed after reconfiguration @@ -445,6 +454,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_CHAT_MODEL: "gpt-4o", CONF_TOP_P: 0.9, CONF_MAX_TOKENS: 1000, + CONF_CODE_INTERPRETER: True, CONF_WEB_SEARCH: True, CONF_WEB_SEARCH_CONTEXT_SIZE: "low", CONF_WEB_SEARCH_USER_LOCATION: True, @@ -476,6 +486,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_TOP_P: 0.9, CONF_MAX_TOKENS: 1000, CONF_REASONING_EFFORT: "high", + CONF_CODE_INTERPRETER: True, }, ( { @@ -504,6 +515,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_WEB_SEARCH_REGION: "California", CONF_WEB_SEARCH_COUNTRY: "US", CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles", + CONF_CODE_INTERPRETER: True, }, ( { @@ -518,6 +530,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non }, { CONF_REASONING_EFFORT: "low", + CONF_CODE_INTERPRETER: True, }, ), { @@ -528,6 +541,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_TOP_P: 0.9, CONF_MAX_TOKENS: 1000, CONF_REASONING_EFFORT: "low", + CONF_CODE_INTERPRETER: True, }, ), ( # Case 4: reasoning to web search @@ -540,6 +554,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_TOP_P: 0.9, CONF_MAX_TOKENS: 1000, CONF_REASONING_EFFORT: "low", + CONF_CODE_INTERPRETER: True, }, ( { @@ -556,6 +571,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_WEB_SEARCH: True, CONF_WEB_SEARCH_CONTEXT_SIZE: "high", CONF_WEB_SEARCH_USER_LOCATION: False, + CONF_CODE_INTERPRETER: False, }, ), { @@ -568,6 +584,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non CONF_WEB_SEARCH: True, CONF_WEB_SEARCH_CONTEXT_SIZE: "high", CONF_WEB_SEARCH_USER_LOCATION: False, + CONF_CODE_INTERPRETER: False, }, ), ], @@ -718,6 +735,7 @@ async def test_subentry_web_search_user_location( CONF_WEB_SEARCH_REGION: "California", CONF_WEB_SEARCH_COUNTRY: "US", CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles", + CONF_CODE_INTERPRETER: False, } @@ -817,12 +835,24 @@ async def test_creating_ai_task_subentry_advanced( }, ) - assert result3.get("type") is FlowResultType.CREATE_ENTRY - assert result3.get("title") == "Advanced AI Task" - assert result3.get("data") == { + assert result3.get("type") is FlowResultType.FORM + assert result3.get("step_id") == "model" + + # Configure model settings + result4 = await hass.config_entries.subentries.async_configure( + result["flow_id"], + { + CONF_CODE_INTERPRETER: False, + }, + ) + + assert result4.get("type") is FlowResultType.CREATE_ENTRY + assert result4.get("title") == "Advanced AI Task" + assert result4.get("data") == { CONF_RECOMMENDED: False, CONF_CHAT_MODEL: "gpt-4o", CONF_MAX_TOKENS: 200, CONF_TEMPERATURE: 0.5, CONF_TOP_P: 0.9, + CONF_CODE_INTERPRETER: False, } diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 39cd129e1ba..dafcba7bfeb 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -16,6 +16,7 @@ from syrupy.assertion import SnapshotAssertion from homeassistant.components import conversation from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.components.openai_conversation.const import ( + CONF_CODE_INTERPRETER, CONF_WEB_SEARCH, CONF_WEB_SEARCH_CITY, CONF_WEB_SEARCH_CONTEXT_SIZE, @@ -30,6 +31,7 @@ from homeassistant.helpers import intent from homeassistant.setup import async_setup_component from . import ( + create_code_interpreter_item, create_function_tool_call_item, create_message_item, create_reasoning_item, @@ -485,3 +487,49 @@ async def test_web_search( ] assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.speech["plain"]["speech"] == message, result.response.speech + + +async def test_code_interpreter( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + mock_create_stream, + mock_chat_log: MockChatLog, # noqa: F811 +) -> None: + """Test code_interpreter tool.""" + subentry = next(iter(mock_config_entry.subentries.values())) + hass.config_entries.async_update_subentry( + mock_config_entry, + subentry, + data={ + **subentry.data, + CONF_CODE_INTERPRETER: True, + }, + ) + await hass.config_entries.async_reload(mock_config_entry.entry_id) + + message = "I’ve calculated it with Python: the square root of 55555 is approximately 235.70108188126758." + mock_create_stream.return_value = [ + ( + *create_code_interpreter_item( + id="ci_A", + code=["import", " math", "\n", "math", ".sqrt", "(", "555", "55", ")"], + output_index=0, + ), + *create_message_item(id="msg_A", text=message, output_index=1), + ) + ] + + result = await conversation.async_converse( + hass, + "Please use the python tool to calculate square root of 55555", + mock_chat_log.conversation_id, + Context(), + agent_id="conversation.openai_conversation", + ) + + assert mock_create_stream.mock_calls[0][2]["tools"] == [ + {"type": "code_interpreter", "container": {"type": "auto"}} + ] + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert result.response.speech["plain"]["speech"] == message, result.response.speech