Mirror of https://github.com/home-assistant/core.git, synced 2025-07-18 18:57:06 +00:00
Add Code Interpreter tool for OpenAI Conversation (#148383)
This commit is contained in: parent 0d79f7db51, commit 3e465da892
homeassistant/components/openai_conversation/config_flow.py

@@ -42,6 +42,7 @@ from homeassistant.helpers.typing import VolDictType
 
 from .const import (
     CONF_CHAT_MODEL,
+    CONF_CODE_INTERPRETER,
     CONF_MAX_TOKENS,
     CONF_PROMPT,
     CONF_REASONING_EFFORT,
@@ -60,6 +61,7 @@ from .const import (
     DOMAIN,
     RECOMMENDED_AI_TASK_OPTIONS,
     RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_CODE_INTERPRETER,
     RECOMMENDED_CONVERSATION_OPTIONS,
     RECOMMENDED_MAX_TOKENS,
     RECOMMENDED_REASONING_EFFORT,
@@ -312,7 +314,12 @@ class OpenAISubentryFlowHandler(ConfigSubentryFlow):
         options = self.options
         errors: dict[str, str] = {}
 
-        step_schema: VolDictType = {}
+        step_schema: VolDictType = {
+            vol.Optional(
+                CONF_CODE_INTERPRETER,
+                default=RECOMMENDED_CODE_INTERPRETER,
+            ): bool,
+        }
 
         model = options[CONF_CHAT_MODEL]
 
@@ -375,18 +382,6 @@ class OpenAISubentryFlowHandler(ConfigSubentryFlow):
             )
         }
 
-        if not step_schema:
-            if self._is_new:
-                return self.async_create_entry(
-                    title=options.pop(CONF_NAME),
-                    data=options,
-                )
-            return self.async_update_and_abort(
-                self._get_entry(),
-                self._get_reconfigure_subentry(),
-                data=options,
-            )
-
         if user_input is not None:
             if user_input.get(CONF_WEB_SEARCH):
                 if user_input.get(CONF_WEB_SEARCH_USER_LOCATION):
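Note: with the code interpreter toggle unconditionally in step_schema, the schema can never be empty, which is why the early "if not step_schema" bail-out above is removed. A minimal standalone sketch of how the new schema fragment validates (constants inlined; nothing here beyond what the diff shows):

# Sketch: vol.Optional with a default pre-fills the toggle but lets users override it.
import voluptuous as vol

CONF_CODE_INTERPRETER = "code_interpreter"
RECOMMENDED_CODE_INTERPRETER = False

schema = vol.Schema(
    {
        vol.Optional(
            CONF_CODE_INTERPRETER,
            default=RECOMMENDED_CODE_INTERPRETER,
        ): bool,
    }
)

print(schema({}))  # {'code_interpreter': False} (default applied)
print(schema({CONF_CODE_INTERPRETER: True}))  # {'code_interpreter': True}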
homeassistant/components/openai_conversation/const.py

@@ -13,6 +13,7 @@ DEFAULT_AI_TASK_NAME = "OpenAI AI Task"
 DEFAULT_NAME = "OpenAI Conversation"
 
 CONF_CHAT_MODEL = "chat_model"
+CONF_CODE_INTERPRETER = "code_interpreter"
 CONF_FILENAMES = "filenames"
 CONF_MAX_TOKENS = "max_tokens"
 CONF_PROMPT = "prompt"
@@ -27,6 +28,7 @@ CONF_WEB_SEARCH_CITY = "city"
 CONF_WEB_SEARCH_REGION = "region"
 CONF_WEB_SEARCH_COUNTRY = "country"
 CONF_WEB_SEARCH_TIMEZONE = "timezone"
+RECOMMENDED_CODE_INTERPRETER = False
 RECOMMENDED_CHAT_MODEL = "gpt-4o-mini"
 RECOMMENDED_MAX_TOKENS = 3000
 RECOMMENDED_REASONING_EFFORT = "low"
homeassistant/components/openai_conversation/entity.py

@@ -38,6 +38,10 @@ from openai.types.responses import (
     WebSearchToolParam,
 )
 from openai.types.responses.response_input_param import FunctionCallOutput
+from openai.types.responses.tool_param import (
+    CodeInterpreter,
+    CodeInterpreterContainerCodeInterpreterToolAuto,
+)
 from openai.types.responses.web_search_tool_param import UserLocation
 import voluptuous as vol
 from voluptuous_openapi import convert
@@ -52,6 +56,7 @@ from homeassistant.util import slugify
 
 from .const import (
     CONF_CHAT_MODEL,
+    CONF_CODE_INTERPRETER,
     CONF_MAX_TOKENS,
     CONF_REASONING_EFFORT,
     CONF_TEMPERATURE,
@@ -292,7 +297,7 @@ class OpenAIBaseLLMEntity(Entity):
         """Generate an answer for the chat log."""
         options = self.subentry.data
 
-        tools: list[ToolParam] | None = None
+        tools: list[ToolParam] = []
         if chat_log.llm_api:
             tools = [
                 _format_tool(tool, chat_log.llm_api.custom_serializer)
@@ -314,10 +319,18 @@ class OpenAIBaseLLMEntity(Entity):
                 country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
                 timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
             )
-            if tools is None:
-                tools = []
             tools.append(web_search)
 
+        if options.get(CONF_CODE_INTERPRETER):
+            tools.append(
+                CodeInterpreter(
+                    type="code_interpreter",
+                    container=CodeInterpreterContainerCodeInterpreterToolAuto(
+                        type="auto"
+                    ),
+                )
+            )
+
         model_args = {
             "model": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
             "input": [],
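Since the SDK's tool params are TypedDicts, the appended CodeInterpreter value is plain JSON-shaped data, and the conversation test further down asserts exactly that shape. A quick sketch, assuming openai-python's typed params as imported above:

# Sketch: constructing the TypedDict tool param yields the dict payload
# the Responses API request carries.
from openai.types.responses.tool_param import (
    CodeInterpreter,
    CodeInterpreterContainerCodeInterpreterToolAuto,
)

tool = CodeInterpreter(
    type="code_interpreter",
    # "auto" lets the API provision a container for the call on demand.
    container=CodeInterpreterContainerCodeInterpreterToolAuto(type="auto"),
)
assert tool == {"type": "code_interpreter", "container": {"type": "auto"}}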
homeassistant/components/openai_conversation/strings.json

@@ -48,12 +48,14 @@
       "model": {
         "title": "Model-specific options",
         "data": {
+          "code_interpreter": "Enable code interpreter tool",
           "reasoning_effort": "Reasoning effort",
           "web_search": "Enable web search",
           "search_context_size": "Search context size",
           "user_location": "Include home location"
         },
         "data_description": {
+          "code_interpreter": "This tool, also known as the python tool to the model, allows it to run code to answer questions",
           "reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt",
           "web_search": "Allow the model to search the web for the latest information before generating a response",
           "search_context_size": "High level guidance for the amount of context window space to use for the search",
tests/components/openai_conversation/__init__.py

@@ -1,6 +1,12 @@
 """Tests for the OpenAI Conversation integration."""
 
 from openai.types.responses import (
+    ResponseCodeInterpreterCallCodeDeltaEvent,
+    ResponseCodeInterpreterCallCodeDoneEvent,
+    ResponseCodeInterpreterCallCompletedEvent,
+    ResponseCodeInterpreterCallInProgressEvent,
+    ResponseCodeInterpreterCallInterpretingEvent,
+    ResponseCodeInterpreterToolCall,
     ResponseContentPartAddedEvent,
     ResponseContentPartDoneEvent,
     ResponseFunctionCallArgumentsDeltaEvent,
@@ -239,3 +245,86 @@ def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
         type="response.output_item.done",
     ),
 ]
+
+
+def create_code_interpreter_item(
+    id: str, code: str | list[str], output_index: int
+) -> list[ResponseStreamEvent]:
+    """Create a code interpreter item."""
+    if isinstance(code, str):
+        code = [code]
+
+    container_id = "cntr_A"
+    events = [
+        ResponseOutputItemAddedEvent(
+            item=ResponseCodeInterpreterToolCall(
+                id=id,
+                code="",
+                container_id=container_id,
+                outputs=None,
+                type="code_interpreter_call",
+                status="in_progress",
+            ),
+            output_index=output_index,
+            sequence_number=0,
+            type="response.output_item.added",
+        ),
+        ResponseCodeInterpreterCallInProgressEvent(
+            item_id=id,
+            output_index=output_index,
+            sequence_number=0,
+            type="response.code_interpreter_call.in_progress",
+        ),
+    ]
+
+    events.extend(
+        ResponseCodeInterpreterCallCodeDeltaEvent(
+            delta=delta,
+            item_id=id,
+            output_index=output_index,
+            sequence_number=0,
+            type="response.code_interpreter_call_code.delta",
+        )
+        for delta in code
+    )
+
+    code = "".join(code)
+
+    events.extend(
+        [
+            ResponseCodeInterpreterCallCodeDoneEvent(
+                item_id=id,
+                output_index=output_index,
+                code=code,
+                sequence_number=0,
+                type="response.code_interpreter_call_code.done",
+            ),
+            ResponseCodeInterpreterCallInterpretingEvent(
+                item_id=id,
+                output_index=output_index,
+                sequence_number=0,
+                type="response.code_interpreter_call.interpreting",
+            ),
+            ResponseCodeInterpreterCallCompletedEvent(
+                item_id=id,
+                output_index=output_index,
+                sequence_number=0,
+                type="response.code_interpreter_call.completed",
+            ),
+            ResponseOutputItemDoneEvent(
+                item=ResponseCodeInterpreterToolCall(
+                    id=id,
+                    code=code,
+                    container_id=container_id,
+                    outputs=None,
+                    status="completed",
+                    type="code_interpreter_call",
+                ),
+                output_index=output_index,
+                sequence_number=0,
+                type="response.output_item.done",
+            ),
+        ]
+    )
+
+    return events
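For reference, a hypothetical usage of the new helper (ids and code split invented): feeding two deltas produces the full Responses event sequence, in stream order.

# Sketch: the helper mirrors a real code interpreter call streamed in pieces.
events = create_code_interpreter_item(
    id="ci_demo", code=["print(", "1)"], output_index=0
)
assert [event.type for event in events] == [
    "response.output_item.added",
    "response.code_interpreter_call.in_progress",
    "response.code_interpreter_call_code.delta",  # "print("
    "response.code_interpreter_call_code.delta",  # "1)"
    "response.code_interpreter_call_code.done",  # code="print(1)"
    "response.code_interpreter_call.interpreting",
    "response.code_interpreter_call.completed",
    "response.output_item.done",
]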
tests/components/openai_conversation/conftest.py

@@ -156,9 +156,10 @@ def mock_create_stream() -> Generator[AsyncMock]:
         )
         yield ResponseInProgressEvent(
             response=response,
-            sequence_number=0,
+            sequence_number=1,
             type="response.in_progress",
         )
+        sequence_number = 2
         response.status = "completed"
 
         for value in events:
@@ -173,6 +174,8 @@ def mock_create_stream() -> Generator[AsyncMock]:
                     response.error = value
                     break
 
+            value.sequence_number = sequence_number
+            sequence_number += 1
             yield value
 
             if isinstance(value, ResponseErrorEvent):
@@ -181,19 +184,19 @@ def mock_create_stream() -> Generator[AsyncMock]:
         if response.status == "incomplete":
             yield ResponseIncompleteEvent(
                 response=response,
-                sequence_number=0,
+                sequence_number=sequence_number,
                 type="response.incomplete",
             )
         elif response.status == "failed":
             yield ResponseFailedEvent(
                 response=response,
-                sequence_number=0,
+                sequence_number=sequence_number,
                 type="response.failed",
            )
         else:
             yield ResponseCompletedEvent(
                 response=response,
-                sequence_number=0,
+                sequence_number=sequence_number,
                 type="response.completed",
             )
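The conftest change makes the mock stream emit monotonically increasing sequence numbers instead of all zeros, which the multi-event code interpreter fixtures rely on. A minimal runnable sketch of the numbering scheme, assuming the preceding response.created event keeps sequence 0 (which is why the counter starts at 2):

from types import SimpleNamespace

def renumber(events):
    """Assign sequence numbers 2..N to streamed events; return the next number."""
    sequence_number = 2  # 0 = response.created, 1 = response.in_progress (assumed)
    for value in events:
        value.sequence_number = sequence_number
        sequence_number += 1
    return sequence_number

events = [SimpleNamespace(), SimpleNamespace()]
final = renumber(events)
assert [e.sequence_number for e in events] == [2, 3]
assert final == 4  # the terminal completed/failed/incomplete event takes this value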
tests/components/openai_conversation/test_config_flow.py

@@ -13,6 +13,7 @@ from homeassistant.components.openai_conversation.config_flow import (
 )
 from homeassistant.components.openai_conversation.const import (
     CONF_CHAT_MODEL,
+    CONF_CODE_INTERPRETER,
     CONF_MAX_TOKENS,
     CONF_PROMPT,
     CONF_REASONING_EFFORT,
@@ -311,6 +312,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
            },
            {
                CONF_REASONING_EFFORT: "high",
+               CONF_CODE_INTERPRETER: True,
            },
        ),
        {
@@ -321,6 +323,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
            CONF_TOP_P: RECOMMENDED_TOP_P,
            CONF_MAX_TOKENS: 10000,
            CONF_REASONING_EFFORT: "high",
+           CONF_CODE_INTERPRETER: True,
        },
    ),
    (  # options for web search without user location
@@ -343,6 +346,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
            CONF_WEB_SEARCH: True,
            CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
            CONF_WEB_SEARCH_USER_LOCATION: False,
+           CONF_CODE_INTERPRETER: False,
        },
    ),
    {
@@ -355,6 +359,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
        CONF_WEB_SEARCH: True,
        CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
        CONF_WEB_SEARCH_USER_LOCATION: False,
+       CONF_CODE_INTERPRETER: False,
    },
    ),
    # Test that current options are shown as suggested values
@@ -373,6 +378,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
            CONF_WEB_SEARCH_REGION: "California",
            CONF_WEB_SEARCH_COUNTRY: "US",
            CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
+           CONF_CODE_INTERPRETER: True,
        },
        (
            {
@@ -389,6 +395,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
            CONF_WEB_SEARCH: True,
            CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
            CONF_WEB_SEARCH_USER_LOCATION: False,
+           CONF_CODE_INTERPRETER: True,
        },
    ),
    {
@@ -401,6 +408,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
        CONF_WEB_SEARCH: True,
        CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
        CONF_WEB_SEARCH_USER_LOCATION: False,
+       CONF_CODE_INTERPRETER: True,
    },
    ),
    (  # Case 2: reasoning model
@@ -424,7 +432,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
            CONF_TOP_P: 0.9,
            CONF_MAX_TOKENS: 1000,
        },
-       {CONF_REASONING_EFFORT: "high"},
+       {CONF_REASONING_EFFORT: "high", CONF_CODE_INTERPRETER: False},
    ),
    {
        CONF_RECOMMENDED: False,
@@ -434,6 +442,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
        CONF_TOP_P: 0.9,
        CONF_MAX_TOKENS: 1000,
        CONF_REASONING_EFFORT: "high",
+       CONF_CODE_INTERPRETER: False,
    },
    ),
    # Test that old options are removed after reconfiguration
@@ -445,6 +454,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
        CONF_CHAT_MODEL: "gpt-4o",
        CONF_TOP_P: 0.9,
        CONF_MAX_TOKENS: 1000,
+       CONF_CODE_INTERPRETER: True,
        CONF_WEB_SEARCH: True,
        CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
        CONF_WEB_SEARCH_USER_LOCATION: True,
@@ -476,6 +486,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
        CONF_TOP_P: 0.9,
        CONF_MAX_TOKENS: 1000,
        CONF_REASONING_EFFORT: "high",
+       CONF_CODE_INTERPRETER: True,
    },
    (
        {
@@ -504,6 +515,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
        CONF_WEB_SEARCH_REGION: "California",
        CONF_WEB_SEARCH_COUNTRY: "US",
        CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
+       CONF_CODE_INTERPRETER: True,
    },
    (
        {
@@ -518,6 +530,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
            },
            {
                CONF_REASONING_EFFORT: "low",
+               CONF_CODE_INTERPRETER: True,
            },
        ),
        {
@@ -528,6 +541,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
        CONF_TOP_P: 0.9,
        CONF_MAX_TOKENS: 1000,
        CONF_REASONING_EFFORT: "low",
+       CONF_CODE_INTERPRETER: True,
    },
    ),
    (  # Case 4: reasoning to web search
@@ -540,6 +554,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
        CONF_TOP_P: 0.9,
        CONF_MAX_TOKENS: 1000,
        CONF_REASONING_EFFORT: "low",
+       CONF_CODE_INTERPRETER: True,
    },
    (
        {
@@ -556,6 +571,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
            CONF_WEB_SEARCH: True,
            CONF_WEB_SEARCH_CONTEXT_SIZE: "high",
            CONF_WEB_SEARCH_USER_LOCATION: False,
+           CONF_CODE_INTERPRETER: False,
        },
    ),
    {
@@ -568,6 +584,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
        CONF_WEB_SEARCH: True,
        CONF_WEB_SEARCH_CONTEXT_SIZE: "high",
        CONF_WEB_SEARCH_USER_LOCATION: False,
+       CONF_CODE_INTERPRETER: False,
    },
    ),
    ],
@@ -718,6 +735,7 @@ async def test_subentry_web_search_user_location(
        CONF_WEB_SEARCH_REGION: "California",
        CONF_WEB_SEARCH_COUNTRY: "US",
        CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
+       CONF_CODE_INTERPRETER: False,
    }
 
 
@@ -817,12 +835,24 @@ async def test_creating_ai_task_subentry_advanced(
         },
     )
 
-    assert result3.get("type") is FlowResultType.CREATE_ENTRY
-    assert result3.get("title") == "Advanced AI Task"
-    assert result3.get("data") == {
+    assert result3.get("type") is FlowResultType.FORM
+    assert result3.get("step_id") == "model"
+
+    # Configure model settings
+    result4 = await hass.config_entries.subentries.async_configure(
+        result["flow_id"],
+        {
+            CONF_CODE_INTERPRETER: False,
+        },
+    )
+
+    assert result4.get("type") is FlowResultType.CREATE_ENTRY
+    assert result4.get("title") == "Advanced AI Task"
+    assert result4.get("data") == {
         CONF_RECOMMENDED: False,
         CONF_CHAT_MODEL: "gpt-4o",
         CONF_MAX_TOKENS: 200,
         CONF_TEMPERATURE: 0.5,
         CONF_TOP_P: 0.9,
+        CONF_CODE_INTERPRETER: False,
     }
tests/components/openai_conversation/test_conversation.py

@@ -16,6 +16,7 @@ from syrupy.assertion import SnapshotAssertion
 from homeassistant.components import conversation
 from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
 from homeassistant.components.openai_conversation.const import (
+    CONF_CODE_INTERPRETER,
     CONF_WEB_SEARCH,
     CONF_WEB_SEARCH_CITY,
     CONF_WEB_SEARCH_CONTEXT_SIZE,
@@ -30,6 +31,7 @@ from homeassistant.helpers import intent
 from homeassistant.setup import async_setup_component
 
 from . import (
+    create_code_interpreter_item,
     create_function_tool_call_item,
     create_message_item,
     create_reasoning_item,
@@ -485,3 +487,49 @@ async def test_web_search(
     ]
     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
     assert result.response.speech["plain"]["speech"] == message, result.response.speech
+
+
+async def test_code_interpreter(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+    mock_create_stream,
+    mock_chat_log: MockChatLog,  # noqa: F811
+) -> None:
+    """Test code_interpreter tool."""
+    subentry = next(iter(mock_config_entry.subentries.values()))
+    hass.config_entries.async_update_subentry(
+        mock_config_entry,
+        subentry,
+        data={
+            **subentry.data,
+            CONF_CODE_INTERPRETER: True,
+        },
+    )
+    await hass.config_entries.async_reload(mock_config_entry.entry_id)
+
+    message = "I’ve calculated it with Python: the square root of 55555 is approximately 235.70108188126758."
+    mock_create_stream.return_value = [
+        (
+            *create_code_interpreter_item(
+                id="ci_A",
+                code=["import", " math", "\n", "math", ".sqrt", "(", "555", "55", ")"],
+                output_index=0,
+            ),
+            *create_message_item(id="msg_A", text=message, output_index=1),
+        )
+    ]
+
+    result = await conversation.async_converse(
+        hass,
+        "Please use the python tool to calculate square root of 55555",
+        mock_chat_log.conversation_id,
+        Context(),
+        agent_id="conversation.openai_conversation",
+    )
+
+    assert mock_create_stream.mock_calls[0][2]["tools"] == [
+        {"type": "code_interpreter", "container": {"type": "auto"}}
+    ]
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
+    assert result.response.speech["plain"]["speech"] == message, result.response.speech
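A note on the assertion style above: unittest.mock call records unpack as (name, args, kwargs), so mock_calls[0][2]["tools"] reads the tools keyword of the first call to the mocked create method. A tiny illustration (mock and arguments invented):

from unittest.mock import MagicMock

create = MagicMock()
create("input", tools=[{"type": "code_interpreter", "container": {"type": "auto"}}])

name, args, kwargs = create.mock_calls[0]
assert kwargs is create.mock_calls[0][2]  # index 2 is the kwargs dict
assert kwargs["tools"][0]["type"] == "code_interpreter"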