mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 03:07:37 +00:00
Add Code Interpreter tool for OpenAI Conversation (#148383)
This commit is contained in:
parent
0d79f7db51
commit
3e465da892
@ -42,6 +42,7 @@ from homeassistant.helpers.typing import VolDictType
|
|||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
|
CONF_CODE_INTERPRETER,
|
||||||
CONF_MAX_TOKENS,
|
CONF_MAX_TOKENS,
|
||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
CONF_REASONING_EFFORT,
|
CONF_REASONING_EFFORT,
|
||||||
@ -60,6 +61,7 @@ from .const import (
|
|||||||
DOMAIN,
|
DOMAIN,
|
||||||
RECOMMENDED_AI_TASK_OPTIONS,
|
RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
RECOMMENDED_CODE_INTERPRETER,
|
||||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||||
RECOMMENDED_MAX_TOKENS,
|
RECOMMENDED_MAX_TOKENS,
|
||||||
RECOMMENDED_REASONING_EFFORT,
|
RECOMMENDED_REASONING_EFFORT,
|
||||||
@ -312,7 +314,12 @@ class OpenAISubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
options = self.options
|
options = self.options
|
||||||
errors: dict[str, str] = {}
|
errors: dict[str, str] = {}
|
||||||
|
|
||||||
step_schema: VolDictType = {}
|
step_schema: VolDictType = {
|
||||||
|
vol.Optional(
|
||||||
|
CONF_CODE_INTERPRETER,
|
||||||
|
default=RECOMMENDED_CODE_INTERPRETER,
|
||||||
|
): bool,
|
||||||
|
}
|
||||||
|
|
||||||
model = options[CONF_CHAT_MODEL]
|
model = options[CONF_CHAT_MODEL]
|
||||||
|
|
||||||
@ -375,18 +382,6 @@ class OpenAISubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if not step_schema:
|
|
||||||
if self._is_new:
|
|
||||||
return self.async_create_entry(
|
|
||||||
title=options.pop(CONF_NAME),
|
|
||||||
data=options,
|
|
||||||
)
|
|
||||||
return self.async_update_and_abort(
|
|
||||||
self._get_entry(),
|
|
||||||
self._get_reconfigure_subentry(),
|
|
||||||
data=options,
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
if user_input.get(CONF_WEB_SEARCH):
|
if user_input.get(CONF_WEB_SEARCH):
|
||||||
if user_input.get(CONF_WEB_SEARCH_USER_LOCATION):
|
if user_input.get(CONF_WEB_SEARCH_USER_LOCATION):
|
||||||
|
@ -13,6 +13,7 @@ DEFAULT_AI_TASK_NAME = "OpenAI AI Task"
|
|||||||
DEFAULT_NAME = "OpenAI Conversation"
|
DEFAULT_NAME = "OpenAI Conversation"
|
||||||
|
|
||||||
CONF_CHAT_MODEL = "chat_model"
|
CONF_CHAT_MODEL = "chat_model"
|
||||||
|
CONF_CODE_INTERPRETER = "code_interpreter"
|
||||||
CONF_FILENAMES = "filenames"
|
CONF_FILENAMES = "filenames"
|
||||||
CONF_MAX_TOKENS = "max_tokens"
|
CONF_MAX_TOKENS = "max_tokens"
|
||||||
CONF_PROMPT = "prompt"
|
CONF_PROMPT = "prompt"
|
||||||
@ -27,6 +28,7 @@ CONF_WEB_SEARCH_CITY = "city"
|
|||||||
CONF_WEB_SEARCH_REGION = "region"
|
CONF_WEB_SEARCH_REGION = "region"
|
||||||
CONF_WEB_SEARCH_COUNTRY = "country"
|
CONF_WEB_SEARCH_COUNTRY = "country"
|
||||||
CONF_WEB_SEARCH_TIMEZONE = "timezone"
|
CONF_WEB_SEARCH_TIMEZONE = "timezone"
|
||||||
|
RECOMMENDED_CODE_INTERPRETER = False
|
||||||
RECOMMENDED_CHAT_MODEL = "gpt-4o-mini"
|
RECOMMENDED_CHAT_MODEL = "gpt-4o-mini"
|
||||||
RECOMMENDED_MAX_TOKENS = 3000
|
RECOMMENDED_MAX_TOKENS = 3000
|
||||||
RECOMMENDED_REASONING_EFFORT = "low"
|
RECOMMENDED_REASONING_EFFORT = "low"
|
||||||
|
@ -38,6 +38,10 @@ from openai.types.responses import (
|
|||||||
WebSearchToolParam,
|
WebSearchToolParam,
|
||||||
)
|
)
|
||||||
from openai.types.responses.response_input_param import FunctionCallOutput
|
from openai.types.responses.response_input_param import FunctionCallOutput
|
||||||
|
from openai.types.responses.tool_param import (
|
||||||
|
CodeInterpreter,
|
||||||
|
CodeInterpreterContainerCodeInterpreterToolAuto,
|
||||||
|
)
|
||||||
from openai.types.responses.web_search_tool_param import UserLocation
|
from openai.types.responses.web_search_tool_param import UserLocation
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous_openapi import convert
|
from voluptuous_openapi import convert
|
||||||
@ -52,6 +56,7 @@ from homeassistant.util import slugify
|
|||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
|
CONF_CODE_INTERPRETER,
|
||||||
CONF_MAX_TOKENS,
|
CONF_MAX_TOKENS,
|
||||||
CONF_REASONING_EFFORT,
|
CONF_REASONING_EFFORT,
|
||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
@ -292,7 +297,7 @@ class OpenAIBaseLLMEntity(Entity):
|
|||||||
"""Generate an answer for the chat log."""
|
"""Generate an answer for the chat log."""
|
||||||
options = self.subentry.data
|
options = self.subentry.data
|
||||||
|
|
||||||
tools: list[ToolParam] | None = None
|
tools: list[ToolParam] = []
|
||||||
if chat_log.llm_api:
|
if chat_log.llm_api:
|
||||||
tools = [
|
tools = [
|
||||||
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
||||||
@ -314,10 +319,18 @@ class OpenAIBaseLLMEntity(Entity):
|
|||||||
country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
|
country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
|
||||||
timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
|
timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
|
||||||
)
|
)
|
||||||
if tools is None:
|
|
||||||
tools = []
|
|
||||||
tools.append(web_search)
|
tools.append(web_search)
|
||||||
|
|
||||||
|
if options.get(CONF_CODE_INTERPRETER):
|
||||||
|
tools.append(
|
||||||
|
CodeInterpreter(
|
||||||
|
type="code_interpreter",
|
||||||
|
container=CodeInterpreterContainerCodeInterpreterToolAuto(
|
||||||
|
type="auto"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
model_args = {
|
model_args = {
|
||||||
"model": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
"model": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||||
"input": [],
|
"input": [],
|
||||||
|
@ -48,12 +48,14 @@
|
|||||||
"model": {
|
"model": {
|
||||||
"title": "Model-specific options",
|
"title": "Model-specific options",
|
||||||
"data": {
|
"data": {
|
||||||
|
"code_interpreter": "Enable code interpreter tool",
|
||||||
"reasoning_effort": "Reasoning effort",
|
"reasoning_effort": "Reasoning effort",
|
||||||
"web_search": "Enable web search",
|
"web_search": "Enable web search",
|
||||||
"search_context_size": "Search context size",
|
"search_context_size": "Search context size",
|
||||||
"user_location": "Include home location"
|
"user_location": "Include home location"
|
||||||
},
|
},
|
||||||
"data_description": {
|
"data_description": {
|
||||||
|
"code_interpreter": "This tool, also known as the python tool to the model, allows it to run code to answer questions",
|
||||||
"reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt",
|
"reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt",
|
||||||
"web_search": "Allow the model to search the web for the latest information before generating a response",
|
"web_search": "Allow the model to search the web for the latest information before generating a response",
|
||||||
"search_context_size": "High level guidance for the amount of context window space to use for the search",
|
"search_context_size": "High level guidance for the amount of context window space to use for the search",
|
||||||
|
@ -1,6 +1,12 @@
|
|||||||
"""Tests for the OpenAI Conversation integration."""
|
"""Tests for the OpenAI Conversation integration."""
|
||||||
|
|
||||||
from openai.types.responses import (
|
from openai.types.responses import (
|
||||||
|
ResponseCodeInterpreterCallCodeDeltaEvent,
|
||||||
|
ResponseCodeInterpreterCallCodeDoneEvent,
|
||||||
|
ResponseCodeInterpreterCallCompletedEvent,
|
||||||
|
ResponseCodeInterpreterCallInProgressEvent,
|
||||||
|
ResponseCodeInterpreterCallInterpretingEvent,
|
||||||
|
ResponseCodeInterpreterToolCall,
|
||||||
ResponseContentPartAddedEvent,
|
ResponseContentPartAddedEvent,
|
||||||
ResponseContentPartDoneEvent,
|
ResponseContentPartDoneEvent,
|
||||||
ResponseFunctionCallArgumentsDeltaEvent,
|
ResponseFunctionCallArgumentsDeltaEvent,
|
||||||
@ -239,3 +245,86 @@ def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEve
|
|||||||
type="response.output_item.done",
|
type="response.output_item.done",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_code_interpreter_item(
|
||||||
|
id: str, code: str | list[str], output_index: int
|
||||||
|
) -> list[ResponseStreamEvent]:
|
||||||
|
"""Create a message item."""
|
||||||
|
if isinstance(code, str):
|
||||||
|
code = [code]
|
||||||
|
|
||||||
|
container_id = "cntr_A"
|
||||||
|
events = [
|
||||||
|
ResponseOutputItemAddedEvent(
|
||||||
|
item=ResponseCodeInterpreterToolCall(
|
||||||
|
id=id,
|
||||||
|
code="",
|
||||||
|
container_id=container_id,
|
||||||
|
outputs=None,
|
||||||
|
type="code_interpreter_call",
|
||||||
|
status="in_progress",
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.output_item.added",
|
||||||
|
),
|
||||||
|
ResponseCodeInterpreterCallInProgressEvent(
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.code_interpreter_call.in_progress",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
events.extend(
|
||||||
|
ResponseCodeInterpreterCallCodeDeltaEvent(
|
||||||
|
delta=delta,
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.code_interpreter_call_code.delta",
|
||||||
|
)
|
||||||
|
for delta in code
|
||||||
|
)
|
||||||
|
|
||||||
|
code = "".join(code)
|
||||||
|
|
||||||
|
events.extend(
|
||||||
|
[
|
||||||
|
ResponseCodeInterpreterCallCodeDoneEvent(
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
code=code,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.code_interpreter_call_code.done",
|
||||||
|
),
|
||||||
|
ResponseCodeInterpreterCallInterpretingEvent(
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.code_interpreter_call.interpreting",
|
||||||
|
),
|
||||||
|
ResponseCodeInterpreterCallCompletedEvent(
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.code_interpreter_call.completed",
|
||||||
|
),
|
||||||
|
ResponseOutputItemDoneEvent(
|
||||||
|
item=ResponseCodeInterpreterToolCall(
|
||||||
|
id=id,
|
||||||
|
code=code,
|
||||||
|
container_id=container_id,
|
||||||
|
outputs=None,
|
||||||
|
status="completed",
|
||||||
|
type="code_interpreter_call",
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.output_item.done",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return events
|
||||||
|
@ -156,9 +156,10 @@ def mock_create_stream() -> Generator[AsyncMock]:
|
|||||||
)
|
)
|
||||||
yield ResponseInProgressEvent(
|
yield ResponseInProgressEvent(
|
||||||
response=response,
|
response=response,
|
||||||
sequence_number=0,
|
sequence_number=1,
|
||||||
type="response.in_progress",
|
type="response.in_progress",
|
||||||
)
|
)
|
||||||
|
sequence_number = 2
|
||||||
response.status = "completed"
|
response.status = "completed"
|
||||||
|
|
||||||
for value in events:
|
for value in events:
|
||||||
@ -173,6 +174,8 @@ def mock_create_stream() -> Generator[AsyncMock]:
|
|||||||
response.error = value
|
response.error = value
|
||||||
break
|
break
|
||||||
|
|
||||||
|
value.sequence_number = sequence_number
|
||||||
|
sequence_number += 1
|
||||||
yield value
|
yield value
|
||||||
|
|
||||||
if isinstance(value, ResponseErrorEvent):
|
if isinstance(value, ResponseErrorEvent):
|
||||||
@ -181,19 +184,19 @@ def mock_create_stream() -> Generator[AsyncMock]:
|
|||||||
if response.status == "incomplete":
|
if response.status == "incomplete":
|
||||||
yield ResponseIncompleteEvent(
|
yield ResponseIncompleteEvent(
|
||||||
response=response,
|
response=response,
|
||||||
sequence_number=0,
|
sequence_number=sequence_number,
|
||||||
type="response.incomplete",
|
type="response.incomplete",
|
||||||
)
|
)
|
||||||
elif response.status == "failed":
|
elif response.status == "failed":
|
||||||
yield ResponseFailedEvent(
|
yield ResponseFailedEvent(
|
||||||
response=response,
|
response=response,
|
||||||
sequence_number=0,
|
sequence_number=sequence_number,
|
||||||
type="response.failed",
|
type="response.failed",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield ResponseCompletedEvent(
|
yield ResponseCompletedEvent(
|
||||||
response=response,
|
response=response,
|
||||||
sequence_number=0,
|
sequence_number=sequence_number,
|
||||||
type="response.completed",
|
type="response.completed",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ from homeassistant.components.openai_conversation.config_flow import (
|
|||||||
)
|
)
|
||||||
from homeassistant.components.openai_conversation.const import (
|
from homeassistant.components.openai_conversation.const import (
|
||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
|
CONF_CODE_INTERPRETER,
|
||||||
CONF_MAX_TOKENS,
|
CONF_MAX_TOKENS,
|
||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
CONF_REASONING_EFFORT,
|
CONF_REASONING_EFFORT,
|
||||||
@ -311,6 +312,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
CONF_REASONING_EFFORT: "high",
|
CONF_REASONING_EFFORT: "high",
|
||||||
|
CONF_CODE_INTERPRETER: True,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
{
|
{
|
||||||
@ -321,6 +323,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_TOP_P: RECOMMENDED_TOP_P,
|
CONF_TOP_P: RECOMMENDED_TOP_P,
|
||||||
CONF_MAX_TOKENS: 10000,
|
CONF_MAX_TOKENS: 10000,
|
||||||
CONF_REASONING_EFFORT: "high",
|
CONF_REASONING_EFFORT: "high",
|
||||||
|
CONF_CODE_INTERPRETER: True,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
( # options for web search without user location
|
( # options for web search without user location
|
||||||
@ -343,6 +346,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_WEB_SEARCH: True,
|
CONF_WEB_SEARCH: True,
|
||||||
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
|
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
|
||||||
CONF_WEB_SEARCH_USER_LOCATION: False,
|
CONF_WEB_SEARCH_USER_LOCATION: False,
|
||||||
|
CONF_CODE_INTERPRETER: False,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
{
|
{
|
||||||
@ -355,6 +359,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_WEB_SEARCH: True,
|
CONF_WEB_SEARCH: True,
|
||||||
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
|
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
|
||||||
CONF_WEB_SEARCH_USER_LOCATION: False,
|
CONF_WEB_SEARCH_USER_LOCATION: False,
|
||||||
|
CONF_CODE_INTERPRETER: False,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
# Test that current options are showed as suggested values
|
# Test that current options are showed as suggested values
|
||||||
@ -373,6 +378,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_WEB_SEARCH_REGION: "California",
|
CONF_WEB_SEARCH_REGION: "California",
|
||||||
CONF_WEB_SEARCH_COUNTRY: "US",
|
CONF_WEB_SEARCH_COUNTRY: "US",
|
||||||
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
|
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
|
||||||
|
CONF_CODE_INTERPRETER: True,
|
||||||
},
|
},
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
@ -389,6 +395,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_WEB_SEARCH: True,
|
CONF_WEB_SEARCH: True,
|
||||||
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
|
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
|
||||||
CONF_WEB_SEARCH_USER_LOCATION: False,
|
CONF_WEB_SEARCH_USER_LOCATION: False,
|
||||||
|
CONF_CODE_INTERPRETER: True,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
{
|
{
|
||||||
@ -401,6 +408,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_WEB_SEARCH: True,
|
CONF_WEB_SEARCH: True,
|
||||||
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
|
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
|
||||||
CONF_WEB_SEARCH_USER_LOCATION: False,
|
CONF_WEB_SEARCH_USER_LOCATION: False,
|
||||||
|
CONF_CODE_INTERPRETER: True,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
( # Case 2: reasoning model
|
( # Case 2: reasoning model
|
||||||
@ -424,7 +432,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_TOP_P: 0.9,
|
CONF_TOP_P: 0.9,
|
||||||
CONF_MAX_TOKENS: 1000,
|
CONF_MAX_TOKENS: 1000,
|
||||||
},
|
},
|
||||||
{CONF_REASONING_EFFORT: "high"},
|
{CONF_REASONING_EFFORT: "high", CONF_CODE_INTERPRETER: False},
|
||||||
),
|
),
|
||||||
{
|
{
|
||||||
CONF_RECOMMENDED: False,
|
CONF_RECOMMENDED: False,
|
||||||
@ -434,6 +442,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_TOP_P: 0.9,
|
CONF_TOP_P: 0.9,
|
||||||
CONF_MAX_TOKENS: 1000,
|
CONF_MAX_TOKENS: 1000,
|
||||||
CONF_REASONING_EFFORT: "high",
|
CONF_REASONING_EFFORT: "high",
|
||||||
|
CONF_CODE_INTERPRETER: False,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
# Test that old options are removed after reconfiguration
|
# Test that old options are removed after reconfiguration
|
||||||
@ -445,6 +454,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_CHAT_MODEL: "gpt-4o",
|
CONF_CHAT_MODEL: "gpt-4o",
|
||||||
CONF_TOP_P: 0.9,
|
CONF_TOP_P: 0.9,
|
||||||
CONF_MAX_TOKENS: 1000,
|
CONF_MAX_TOKENS: 1000,
|
||||||
|
CONF_CODE_INTERPRETER: True,
|
||||||
CONF_WEB_SEARCH: True,
|
CONF_WEB_SEARCH: True,
|
||||||
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
|
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
|
||||||
CONF_WEB_SEARCH_USER_LOCATION: True,
|
CONF_WEB_SEARCH_USER_LOCATION: True,
|
||||||
@ -476,6 +486,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_TOP_P: 0.9,
|
CONF_TOP_P: 0.9,
|
||||||
CONF_MAX_TOKENS: 1000,
|
CONF_MAX_TOKENS: 1000,
|
||||||
CONF_REASONING_EFFORT: "high",
|
CONF_REASONING_EFFORT: "high",
|
||||||
|
CONF_CODE_INTERPRETER: True,
|
||||||
},
|
},
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
@ -504,6 +515,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_WEB_SEARCH_REGION: "California",
|
CONF_WEB_SEARCH_REGION: "California",
|
||||||
CONF_WEB_SEARCH_COUNTRY: "US",
|
CONF_WEB_SEARCH_COUNTRY: "US",
|
||||||
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
|
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
|
||||||
|
CONF_CODE_INTERPRETER: True,
|
||||||
},
|
},
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
@ -518,6 +530,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
CONF_REASONING_EFFORT: "low",
|
CONF_REASONING_EFFORT: "low",
|
||||||
|
CONF_CODE_INTERPRETER: True,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
{
|
{
|
||||||
@ -528,6 +541,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_TOP_P: 0.9,
|
CONF_TOP_P: 0.9,
|
||||||
CONF_MAX_TOKENS: 1000,
|
CONF_MAX_TOKENS: 1000,
|
||||||
CONF_REASONING_EFFORT: "low",
|
CONF_REASONING_EFFORT: "low",
|
||||||
|
CONF_CODE_INTERPRETER: True,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
( # Case 4: reasoning to web search
|
( # Case 4: reasoning to web search
|
||||||
@ -540,6 +554,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_TOP_P: 0.9,
|
CONF_TOP_P: 0.9,
|
||||||
CONF_MAX_TOKENS: 1000,
|
CONF_MAX_TOKENS: 1000,
|
||||||
CONF_REASONING_EFFORT: "low",
|
CONF_REASONING_EFFORT: "low",
|
||||||
|
CONF_CODE_INTERPRETER: True,
|
||||||
},
|
},
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
@ -556,6 +571,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_WEB_SEARCH: True,
|
CONF_WEB_SEARCH: True,
|
||||||
CONF_WEB_SEARCH_CONTEXT_SIZE: "high",
|
CONF_WEB_SEARCH_CONTEXT_SIZE: "high",
|
||||||
CONF_WEB_SEARCH_USER_LOCATION: False,
|
CONF_WEB_SEARCH_USER_LOCATION: False,
|
||||||
|
CONF_CODE_INTERPRETER: False,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
{
|
{
|
||||||
@ -568,6 +584,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_WEB_SEARCH: True,
|
CONF_WEB_SEARCH: True,
|
||||||
CONF_WEB_SEARCH_CONTEXT_SIZE: "high",
|
CONF_WEB_SEARCH_CONTEXT_SIZE: "high",
|
||||||
CONF_WEB_SEARCH_USER_LOCATION: False,
|
CONF_WEB_SEARCH_USER_LOCATION: False,
|
||||||
|
CONF_CODE_INTERPRETER: False,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -718,6 +735,7 @@ async def test_subentry_web_search_user_location(
|
|||||||
CONF_WEB_SEARCH_REGION: "California",
|
CONF_WEB_SEARCH_REGION: "California",
|
||||||
CONF_WEB_SEARCH_COUNTRY: "US",
|
CONF_WEB_SEARCH_COUNTRY: "US",
|
||||||
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
|
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
|
||||||
|
CONF_CODE_INTERPRETER: False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -817,12 +835,24 @@ async def test_creating_ai_task_subentry_advanced(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result3.get("type") is FlowResultType.CREATE_ENTRY
|
assert result3.get("type") is FlowResultType.FORM
|
||||||
assert result3.get("title") == "Advanced AI Task"
|
assert result3.get("step_id") == "model"
|
||||||
assert result3.get("data") == {
|
|
||||||
|
# Configure model settings
|
||||||
|
result4 = await hass.config_entries.subentries.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{
|
||||||
|
CONF_CODE_INTERPRETER: False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result4.get("type") is FlowResultType.CREATE_ENTRY
|
||||||
|
assert result4.get("title") == "Advanced AI Task"
|
||||||
|
assert result4.get("data") == {
|
||||||
CONF_RECOMMENDED: False,
|
CONF_RECOMMENDED: False,
|
||||||
CONF_CHAT_MODEL: "gpt-4o",
|
CONF_CHAT_MODEL: "gpt-4o",
|
||||||
CONF_MAX_TOKENS: 200,
|
CONF_MAX_TOKENS: 200,
|
||||||
CONF_TEMPERATURE: 0.5,
|
CONF_TEMPERATURE: 0.5,
|
||||||
CONF_TOP_P: 0.9,
|
CONF_TOP_P: 0.9,
|
||||||
|
CONF_CODE_INTERPRETER: False,
|
||||||
}
|
}
|
||||||
|
@ -16,6 +16,7 @@ from syrupy.assertion import SnapshotAssertion
|
|||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||||
from homeassistant.components.openai_conversation.const import (
|
from homeassistant.components.openai_conversation.const import (
|
||||||
|
CONF_CODE_INTERPRETER,
|
||||||
CONF_WEB_SEARCH,
|
CONF_WEB_SEARCH,
|
||||||
CONF_WEB_SEARCH_CITY,
|
CONF_WEB_SEARCH_CITY,
|
||||||
CONF_WEB_SEARCH_CONTEXT_SIZE,
|
CONF_WEB_SEARCH_CONTEXT_SIZE,
|
||||||
@ -30,6 +31,7 @@ from homeassistant.helpers import intent
|
|||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
|
create_code_interpreter_item,
|
||||||
create_function_tool_call_item,
|
create_function_tool_call_item,
|
||||||
create_message_item,
|
create_message_item,
|
||||||
create_reasoning_item,
|
create_reasoning_item,
|
||||||
@ -485,3 +487,49 @@ async def test_web_search(
|
|||||||
]
|
]
|
||||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
assert result.response.speech["plain"]["speech"] == message, result.response.speech
|
assert result.response.speech["plain"]["speech"] == message, result.response.speech
|
||||||
|
|
||||||
|
|
||||||
|
async def test_code_interpreter(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
mock_create_stream,
|
||||||
|
mock_chat_log: MockChatLog, # noqa: F811
|
||||||
|
) -> None:
|
||||||
|
"""Test code_interpreter tool."""
|
||||||
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
|
hass.config_entries.async_update_subentry(
|
||||||
|
mock_config_entry,
|
||||||
|
subentry,
|
||||||
|
data={
|
||||||
|
**subentry.data,
|
||||||
|
CONF_CODE_INTERPRETER: True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.config_entries.async_reload(mock_config_entry.entry_id)
|
||||||
|
|
||||||
|
message = "I’ve calculated it with Python: the square root of 55555 is approximately 235.70108188126758."
|
||||||
|
mock_create_stream.return_value = [
|
||||||
|
(
|
||||||
|
*create_code_interpreter_item(
|
||||||
|
id="ci_A",
|
||||||
|
code=["import", " math", "\n", "math", ".sqrt", "(", "555", "55", ")"],
|
||||||
|
output_index=0,
|
||||||
|
),
|
||||||
|
*create_message_item(id="msg_A", text=message, output_index=1),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"Please use the python tool to calculate square root of 55555",
|
||||||
|
mock_chat_log.conversation_id,
|
||||||
|
Context(),
|
||||||
|
agent_id="conversation.openai_conversation",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_create_stream.mock_calls[0][2]["tools"] == [
|
||||||
|
{"type": "code_interpreter", "container": {"type": "auto"}}
|
||||||
|
]
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
assert result.response.speech["plain"]["speech"] == message, result.response.speech
|
||||||
|
Loading…
x
Reference in New Issue
Block a user