mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 09:47:13 +00:00
Add Web search to OpenAI Conversation integration (#141426)
* Add Web search to OpenAI Conversation integration * Limit search for gpt-4o models * Add more tests
This commit is contained in:
parent
8db91623ec
commit
c974285490
@ -2,22 +2,31 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
|
from homeassistant.components.zone import ENTITY_ID_HOME
|
||||||
from homeassistant.config_entries import (
|
from homeassistant.config_entries import (
|
||||||
ConfigEntry,
|
ConfigEntry,
|
||||||
ConfigFlow,
|
ConfigFlow,
|
||||||
ConfigFlowResult,
|
ConfigFlowResult,
|
||||||
OptionsFlow,
|
OptionsFlow,
|
||||||
)
|
)
|
||||||
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API
|
from homeassistant.const import (
|
||||||
|
ATTR_LATITUDE,
|
||||||
|
ATTR_LONGITUDE,
|
||||||
|
CONF_API_KEY,
|
||||||
|
CONF_LLM_HASS_API,
|
||||||
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import llm
|
from homeassistant.helpers import llm
|
||||||
|
from homeassistant.helpers.httpx_client import get_async_client
|
||||||
from homeassistant.helpers.selector import (
|
from homeassistant.helpers.selector import (
|
||||||
NumberSelector,
|
NumberSelector,
|
||||||
NumberSelectorConfig,
|
NumberSelectorConfig,
|
||||||
@ -37,12 +46,22 @@ from .const import (
|
|||||||
CONF_RECOMMENDED,
|
CONF_RECOMMENDED,
|
||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
|
CONF_WEB_SEARCH,
|
||||||
|
CONF_WEB_SEARCH_CITY,
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE,
|
||||||
|
CONF_WEB_SEARCH_COUNTRY,
|
||||||
|
CONF_WEB_SEARCH_REGION,
|
||||||
|
CONF_WEB_SEARCH_TIMEZONE,
|
||||||
|
CONF_WEB_SEARCH_USER_LOCATION,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
RECOMMENDED_MAX_TOKENS,
|
RECOMMENDED_MAX_TOKENS,
|
||||||
RECOMMENDED_REASONING_EFFORT,
|
RECOMMENDED_REASONING_EFFORT,
|
||||||
RECOMMENDED_TEMPERATURE,
|
RECOMMENDED_TEMPERATURE,
|
||||||
RECOMMENDED_TOP_P,
|
RECOMMENDED_TOP_P,
|
||||||
|
RECOMMENDED_WEB_SEARCH,
|
||||||
|
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
|
||||||
|
RECOMMENDED_WEB_SEARCH_USER_LOCATION,
|
||||||
UNSUPPORTED_MODELS,
|
UNSUPPORTED_MODELS,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -66,7 +85,9 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
|
|||||||
|
|
||||||
Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
|
Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
|
||||||
"""
|
"""
|
||||||
client = openai.AsyncOpenAI(api_key=data[CONF_API_KEY])
|
client = openai.AsyncOpenAI(
|
||||||
|
api_key=data[CONF_API_KEY], http_client=get_async_client(hass)
|
||||||
|
)
|
||||||
await hass.async_add_executor_job(client.with_options(timeout=10.0).models.list)
|
await hass.async_add_executor_job(client.with_options(timeout=10.0).models.list)
|
||||||
|
|
||||||
|
|
||||||
@ -137,7 +158,16 @@ class OpenAIOptionsFlow(OptionsFlow):
|
|||||||
|
|
||||||
if user_input.get(CONF_CHAT_MODEL) in UNSUPPORTED_MODELS:
|
if user_input.get(CONF_CHAT_MODEL) in UNSUPPORTED_MODELS:
|
||||||
errors[CONF_CHAT_MODEL] = "model_not_supported"
|
errors[CONF_CHAT_MODEL] = "model_not_supported"
|
||||||
else:
|
|
||||||
|
if user_input.get(CONF_WEB_SEARCH):
|
||||||
|
if not user_input.get(
|
||||||
|
CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL
|
||||||
|
).startswith("gpt-4o"):
|
||||||
|
errors[CONF_WEB_SEARCH] = "web_search_not_supported"
|
||||||
|
elif user_input.get(CONF_WEB_SEARCH_USER_LOCATION):
|
||||||
|
user_input.update(await self.get_location_data())
|
||||||
|
|
||||||
|
if not errors:
|
||||||
return self.async_create_entry(title="", data=user_input)
|
return self.async_create_entry(title="", data=user_input)
|
||||||
else:
|
else:
|
||||||
# Re-render the options again, now with the recommended options shown/hidden
|
# Re-render the options again, now with the recommended options shown/hidden
|
||||||
@ -156,6 +186,59 @@ class OpenAIOptionsFlow(OptionsFlow):
|
|||||||
errors=errors,
|
errors=errors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_location_data(self) -> dict[str, str]:
|
||||||
|
"""Get approximate location data of the user."""
|
||||||
|
location_data: dict[str, str] = {}
|
||||||
|
zone_home = self.hass.states.get(ENTITY_ID_HOME)
|
||||||
|
if zone_home is not None:
|
||||||
|
client = openai.AsyncOpenAI(
|
||||||
|
api_key=self.config_entry.data[CONF_API_KEY],
|
||||||
|
http_client=get_async_client(self.hass),
|
||||||
|
)
|
||||||
|
location_schema = vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Optional(
|
||||||
|
CONF_WEB_SEARCH_CITY,
|
||||||
|
description="Free text input for the city, e.g. `San Francisco`",
|
||||||
|
): str,
|
||||||
|
vol.Optional(
|
||||||
|
CONF_WEB_SEARCH_REGION,
|
||||||
|
description="Free text input for the region, e.g. `California`",
|
||||||
|
): str,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
response = await client.responses.create(
|
||||||
|
model=RECOMMENDED_CHAT_MODEL,
|
||||||
|
input=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "Where are the following coordinates located: "
|
||||||
|
f"({zone_home.attributes[ATTR_LATITUDE]},"
|
||||||
|
f" {zone_home.attributes[ATTR_LONGITUDE]})?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
text={
|
||||||
|
"format": {
|
||||||
|
"type": "json_schema",
|
||||||
|
"name": "approximate_location",
|
||||||
|
"description": "Approximate location data of the user "
|
||||||
|
"for refined web search results",
|
||||||
|
"schema": convert(location_schema),
|
||||||
|
"strict": False,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
store=False,
|
||||||
|
)
|
||||||
|
location_data = location_schema(json.loads(response.output_text) or {})
|
||||||
|
|
||||||
|
if self.hass.config.country:
|
||||||
|
location_data[CONF_WEB_SEARCH_COUNTRY] = self.hass.config.country
|
||||||
|
location_data[CONF_WEB_SEARCH_TIMEZONE] = self.hass.config.time_zone
|
||||||
|
|
||||||
|
_LOGGER.debug("Location data: %s", location_data)
|
||||||
|
|
||||||
|
return location_data
|
||||||
|
|
||||||
|
|
||||||
def openai_config_option_schema(
|
def openai_config_option_schema(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
@ -227,10 +310,35 @@ def openai_config_option_schema(
|
|||||||
): SelectSelector(
|
): SelectSelector(
|
||||||
SelectSelectorConfig(
|
SelectSelectorConfig(
|
||||||
options=["low", "medium", "high"],
|
options=["low", "medium", "high"],
|
||||||
translation_key="reasoning_effort",
|
translation_key=CONF_REASONING_EFFORT,
|
||||||
mode=SelectSelectorMode.DROPDOWN,
|
mode=SelectSelectorMode.DROPDOWN,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
|
vol.Optional(
|
||||||
|
CONF_WEB_SEARCH,
|
||||||
|
description={"suggested_value": options.get(CONF_WEB_SEARCH)},
|
||||||
|
default=RECOMMENDED_WEB_SEARCH,
|
||||||
|
): bool,
|
||||||
|
vol.Optional(
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE,
|
||||||
|
description={
|
||||||
|
"suggested_value": options.get(CONF_WEB_SEARCH_CONTEXT_SIZE)
|
||||||
|
},
|
||||||
|
default=RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
|
||||||
|
): SelectSelector(
|
||||||
|
SelectSelectorConfig(
|
||||||
|
options=["low", "medium", "high"],
|
||||||
|
translation_key=CONF_WEB_SEARCH_CONTEXT_SIZE,
|
||||||
|
mode=SelectSelectorMode.DROPDOWN,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
vol.Optional(
|
||||||
|
CONF_WEB_SEARCH_USER_LOCATION,
|
||||||
|
description={
|
||||||
|
"suggested_value": options.get(CONF_WEB_SEARCH_USER_LOCATION)
|
||||||
|
},
|
||||||
|
default=RECOMMENDED_WEB_SEARCH_USER_LOCATION,
|
||||||
|
): bool,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return schema
|
return schema
|
||||||
|
@ -14,11 +14,21 @@ CONF_REASONING_EFFORT = "reasoning_effort"
|
|||||||
CONF_RECOMMENDED = "recommended"
|
CONF_RECOMMENDED = "recommended"
|
||||||
CONF_TEMPERATURE = "temperature"
|
CONF_TEMPERATURE = "temperature"
|
||||||
CONF_TOP_P = "top_p"
|
CONF_TOP_P = "top_p"
|
||||||
|
CONF_WEB_SEARCH = "web_search"
|
||||||
|
CONF_WEB_SEARCH_USER_LOCATION = "user_location"
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE = "search_context_size"
|
||||||
|
CONF_WEB_SEARCH_CITY = "city"
|
||||||
|
CONF_WEB_SEARCH_REGION = "region"
|
||||||
|
CONF_WEB_SEARCH_COUNTRY = "country"
|
||||||
|
CONF_WEB_SEARCH_TIMEZONE = "timezone"
|
||||||
RECOMMENDED_CHAT_MODEL = "gpt-4o-mini"
|
RECOMMENDED_CHAT_MODEL = "gpt-4o-mini"
|
||||||
RECOMMENDED_MAX_TOKENS = 150
|
RECOMMENDED_MAX_TOKENS = 150
|
||||||
RECOMMENDED_REASONING_EFFORT = "low"
|
RECOMMENDED_REASONING_EFFORT = "low"
|
||||||
RECOMMENDED_TEMPERATURE = 1.0
|
RECOMMENDED_TEMPERATURE = 1.0
|
||||||
RECOMMENDED_TOP_P = 1.0
|
RECOMMENDED_TOP_P = 1.0
|
||||||
|
RECOMMENDED_WEB_SEARCH = False
|
||||||
|
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE = "medium"
|
||||||
|
RECOMMENDED_WEB_SEARCH_USER_LOCATION = False
|
||||||
|
|
||||||
UNSUPPORTED_MODELS: list[str] = [
|
UNSUPPORTED_MODELS: list[str] = [
|
||||||
"o1-mini",
|
"o1-mini",
|
||||||
|
@ -23,8 +23,10 @@ from openai.types.responses import (
|
|||||||
ResponseStreamEvent,
|
ResponseStreamEvent,
|
||||||
ResponseTextDeltaEvent,
|
ResponseTextDeltaEvent,
|
||||||
ToolParam,
|
ToolParam,
|
||||||
|
WebSearchToolParam,
|
||||||
)
|
)
|
||||||
from openai.types.responses.response_input_param import FunctionCallOutput
|
from openai.types.responses.response_input_param import FunctionCallOutput
|
||||||
|
from openai.types.responses.web_search_tool_param import UserLocation
|
||||||
from voluptuous_openapi import convert
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, conversation
|
from homeassistant.components import assist_pipeline, conversation
|
||||||
@ -43,6 +45,13 @@ from .const import (
|
|||||||
CONF_REASONING_EFFORT,
|
CONF_REASONING_EFFORT,
|
||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
|
CONF_WEB_SEARCH,
|
||||||
|
CONF_WEB_SEARCH_CITY,
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE,
|
||||||
|
CONF_WEB_SEARCH_COUNTRY,
|
||||||
|
CONF_WEB_SEARCH_REGION,
|
||||||
|
CONF_WEB_SEARCH_TIMEZONE,
|
||||||
|
CONF_WEB_SEARCH_USER_LOCATION,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
LOGGER,
|
LOGGER,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
@ -50,6 +59,7 @@ from .const import (
|
|||||||
RECOMMENDED_REASONING_EFFORT,
|
RECOMMENDED_REASONING_EFFORT,
|
||||||
RECOMMENDED_TEMPERATURE,
|
RECOMMENDED_TEMPERATURE,
|
||||||
RECOMMENDED_TOP_P,
|
RECOMMENDED_TOP_P,
|
||||||
|
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Max number of back and forth with the LLM to generate a response
|
# Max number of back and forth with the LLM to generate a response
|
||||||
@ -265,6 +275,25 @@ class OpenAIConversationEntity(
|
|||||||
for tool in chat_log.llm_api.tools
|
for tool in chat_log.llm_api.tools
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if options.get(CONF_WEB_SEARCH):
|
||||||
|
web_search = WebSearchToolParam(
|
||||||
|
type="web_search_preview",
|
||||||
|
search_context_size=options.get(
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if options.get(CONF_WEB_SEARCH_USER_LOCATION):
|
||||||
|
web_search["user_location"] = UserLocation(
|
||||||
|
type="approximate",
|
||||||
|
city=options.get(CONF_WEB_SEARCH_CITY, ""),
|
||||||
|
region=options.get(CONF_WEB_SEARCH_REGION, ""),
|
||||||
|
country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
|
||||||
|
timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
|
||||||
|
)
|
||||||
|
if tools is None:
|
||||||
|
tools = []
|
||||||
|
tools.append(web_search)
|
||||||
|
|
||||||
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||||
messages = [
|
messages = [
|
||||||
m
|
m
|
||||||
|
@ -24,16 +24,23 @@
|
|||||||
"top_p": "Top P",
|
"top_p": "Top P",
|
||||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
|
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
|
||||||
"recommended": "Recommended model settings",
|
"recommended": "Recommended model settings",
|
||||||
"reasoning_effort": "Reasoning effort"
|
"reasoning_effort": "Reasoning effort",
|
||||||
|
"web_search": "Enable web search",
|
||||||
|
"search_context_size": "Search context size",
|
||||||
|
"user_location": "Include home location"
|
||||||
},
|
},
|
||||||
"data_description": {
|
"data_description": {
|
||||||
"prompt": "Instruct how the LLM should respond. This can be a template.",
|
"prompt": "Instruct how the LLM should respond. This can be a template.",
|
||||||
"reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt (for certain reasoning models)"
|
"reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt (for certain reasoning models)",
|
||||||
|
"web_search": "Allow the model to search the web for the latest information before generating a response",
|
||||||
|
"search_context_size": "High level guidance for the amount of context window space to use for the search",
|
||||||
|
"user_location": "Refine search results based on geography"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"error": {
|
"error": {
|
||||||
"model_not_supported": "This model is not supported, please select a different model"
|
"model_not_supported": "This model is not supported, please select a different model",
|
||||||
|
"web_search_not_supported": "Web search is only supported for gpt-4o and gpt-4o-mini models"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"selector": {
|
"selector": {
|
||||||
@ -43,6 +50,13 @@
|
|||||||
"medium": "Medium",
|
"medium": "Medium",
|
||||||
"high": "High"
|
"high": "High"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"search_context_size": {
|
||||||
|
"options": {
|
||||||
|
"low": "Low",
|
||||||
|
"medium": "Medium",
|
||||||
|
"high": "High"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"services": {
|
"services": {
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
"""Test the OpenAI Conversation config flow."""
|
"""Test the OpenAI Conversation config flow."""
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
from httpx import Response
|
import httpx
|
||||||
from openai import APIConnectionError, AuthenticationError, BadRequestError
|
from openai import APIConnectionError, AuthenticationError, BadRequestError
|
||||||
|
from openai.types.responses import Response, ResponseOutputMessage, ResponseOutputText
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
@ -16,6 +17,13 @@ from homeassistant.components.openai_conversation.const import (
|
|||||||
CONF_RECOMMENDED,
|
CONF_RECOMMENDED,
|
||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
|
CONF_WEB_SEARCH,
|
||||||
|
CONF_WEB_SEARCH_CITY,
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE,
|
||||||
|
CONF_WEB_SEARCH_COUNTRY,
|
||||||
|
CONF_WEB_SEARCH_REGION,
|
||||||
|
CONF_WEB_SEARCH_TIMEZONE,
|
||||||
|
CONF_WEB_SEARCH_USER_LOCATION,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
RECOMMENDED_MAX_TOKENS,
|
RECOMMENDED_MAX_TOKENS,
|
||||||
@ -117,13 +125,17 @@ async def test_options_unsupported_model(
|
|||||||
(APIConnectionError(request=None), "cannot_connect"),
|
(APIConnectionError(request=None), "cannot_connect"),
|
||||||
(
|
(
|
||||||
AuthenticationError(
|
AuthenticationError(
|
||||||
response=Response(status_code=None, request=""), body=None, message=None
|
response=httpx.Response(status_code=None, request=""),
|
||||||
|
body=None,
|
||||||
|
message=None,
|
||||||
),
|
),
|
||||||
"invalid_auth",
|
"invalid_auth",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
BadRequestError(
|
BadRequestError(
|
||||||
response=Response(status_code=None, request=""), body=None, message=None
|
response=httpx.Response(status_code=None, request=""),
|
||||||
|
body=None,
|
||||||
|
message=None,
|
||||||
),
|
),
|
||||||
"unknown",
|
"unknown",
|
||||||
),
|
),
|
||||||
@ -172,6 +184,9 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_TOP_P: RECOMMENDED_TOP_P,
|
CONF_TOP_P: RECOMMENDED_TOP_P,
|
||||||
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||||
CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
|
CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
|
||||||
|
CONF_WEB_SEARCH: False,
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE: "medium",
|
||||||
|
CONF_WEB_SEARCH_USER_LOCATION: False,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
@ -183,6 +198,9 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
CONF_TOP_P: RECOMMENDED_TOP_P,
|
CONF_TOP_P: RECOMMENDED_TOP_P,
|
||||||
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||||
CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
|
CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
|
||||||
|
CONF_WEB_SEARCH: False,
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE: "medium",
|
||||||
|
CONF_WEB_SEARCH_USER_LOCATION: False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
CONF_RECOMMENDED: True,
|
CONF_RECOMMENDED: True,
|
||||||
@ -225,3 +243,105 @@ async def test_options_switching(
|
|||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert options["type"] is FlowResultType.CREATE_ENTRY
|
assert options["type"] is FlowResultType.CREATE_ENTRY
|
||||||
assert options["data"] == expected_options
|
assert options["data"] == expected_options
|
||||||
|
|
||||||
|
|
||||||
|
async def test_options_web_search_user_location(
|
||||||
|
hass: HomeAssistant, mock_config_entry, mock_init_component
|
||||||
|
) -> None:
|
||||||
|
"""Test fetching user location."""
|
||||||
|
options_flow = await hass.config_entries.options.async_init(
|
||||||
|
mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
hass.config.country = "US"
|
||||||
|
hass.config.time_zone = "America/Los_Angeles"
|
||||||
|
hass.states.async_set(
|
||||||
|
"zone.home", "0", {"latitude": 37.7749, "longitude": -122.4194}
|
||||||
|
)
|
||||||
|
with patch(
|
||||||
|
"openai.resources.responses.AsyncResponses.create",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_create:
|
||||||
|
mock_create.return_value = Response(
|
||||||
|
object="response",
|
||||||
|
id="resp_A",
|
||||||
|
created_at=1700000000,
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
parallel_tool_calls=True,
|
||||||
|
tool_choice="auto",
|
||||||
|
tools=[],
|
||||||
|
output=[
|
||||||
|
ResponseOutputMessage(
|
||||||
|
type="message",
|
||||||
|
id="msg_A",
|
||||||
|
content=[
|
||||||
|
ResponseOutputText(
|
||||||
|
type="output_text",
|
||||||
|
text='{"city": "San Francisco", "region": "California"}',
|
||||||
|
annotations=[],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
role="assistant",
|
||||||
|
status="completed",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
options = await hass.config_entries.options.async_configure(
|
||||||
|
options_flow["flow_id"],
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_PROMPT: "Speak like a pirate",
|
||||||
|
CONF_TEMPERATURE: 1.0,
|
||||||
|
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||||
|
CONF_TOP_P: RECOMMENDED_TOP_P,
|
||||||
|
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||||
|
CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
|
||||||
|
CONF_WEB_SEARCH: True,
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE: "medium",
|
||||||
|
CONF_WEB_SEARCH_USER_LOCATION: True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert (
|
||||||
|
mock_create.call_args.kwargs["input"][0]["content"] == "Where are the following"
|
||||||
|
" coordinates located: (37.7749, -122.4194)?"
|
||||||
|
)
|
||||||
|
assert options["type"] is FlowResultType.CREATE_ENTRY
|
||||||
|
assert options["data"] == {
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_PROMPT: "Speak like a pirate",
|
||||||
|
CONF_TEMPERATURE: 1.0,
|
||||||
|
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||||
|
CONF_TOP_P: RECOMMENDED_TOP_P,
|
||||||
|
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||||
|
CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
|
||||||
|
CONF_WEB_SEARCH: True,
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE: "medium",
|
||||||
|
CONF_WEB_SEARCH_USER_LOCATION: True,
|
||||||
|
CONF_WEB_SEARCH_CITY: "San Francisco",
|
||||||
|
CONF_WEB_SEARCH_REGION: "California",
|
||||||
|
CONF_WEB_SEARCH_COUNTRY: "US",
|
||||||
|
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_options_web_search_unsupported_model(
|
||||||
|
hass: HomeAssistant, mock_config_entry, mock_init_component
|
||||||
|
) -> None:
|
||||||
|
"""Test the options form giving error about web search not being available."""
|
||||||
|
options_flow = await hass.config_entries.options.async_init(
|
||||||
|
mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
options_flow["flow_id"],
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_PROMPT: "Speak like a pirate",
|
||||||
|
CONF_CHAT_MODEL: "o1-pro",
|
||||||
|
CONF_LLM_HASS_API: "assist",
|
||||||
|
CONF_WEB_SEARCH: True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert result["type"] is FlowResultType.FORM
|
||||||
|
assert result["errors"] == {"web_search": "web_search_not_supported"}
|
||||||
|
@ -18,6 +18,7 @@ from openai.types.responses import (
|
|||||||
ResponseFunctionCallArgumentsDeltaEvent,
|
ResponseFunctionCallArgumentsDeltaEvent,
|
||||||
ResponseFunctionCallArgumentsDoneEvent,
|
ResponseFunctionCallArgumentsDoneEvent,
|
||||||
ResponseFunctionToolCall,
|
ResponseFunctionToolCall,
|
||||||
|
ResponseFunctionWebSearch,
|
||||||
ResponseIncompleteEvent,
|
ResponseIncompleteEvent,
|
||||||
ResponseInProgressEvent,
|
ResponseInProgressEvent,
|
||||||
ResponseOutputItemAddedEvent,
|
ResponseOutputItemAddedEvent,
|
||||||
@ -29,6 +30,9 @@ from openai.types.responses import (
|
|||||||
ResponseTextConfig,
|
ResponseTextConfig,
|
||||||
ResponseTextDeltaEvent,
|
ResponseTextDeltaEvent,
|
||||||
ResponseTextDoneEvent,
|
ResponseTextDoneEvent,
|
||||||
|
ResponseWebSearchCallCompletedEvent,
|
||||||
|
ResponseWebSearchCallInProgressEvent,
|
||||||
|
ResponseWebSearchCallSearchingEvent,
|
||||||
)
|
)
|
||||||
from openai.types.responses.response import IncompleteDetails
|
from openai.types.responses.response import IncompleteDetails
|
||||||
import pytest
|
import pytest
|
||||||
@ -36,6 +40,15 @@ from syrupy.assertion import SnapshotAssertion
|
|||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||||
|
from homeassistant.components.openai_conversation.const import (
|
||||||
|
CONF_WEB_SEARCH,
|
||||||
|
CONF_WEB_SEARCH_CITY,
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE,
|
||||||
|
CONF_WEB_SEARCH_COUNTRY,
|
||||||
|
CONF_WEB_SEARCH_REGION,
|
||||||
|
CONF_WEB_SEARCH_TIMEZONE,
|
||||||
|
CONF_WEB_SEARCH_USER_LOCATION,
|
||||||
|
)
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.helpers import intent
|
from homeassistant.helpers import intent
|
||||||
@ -225,7 +238,6 @@ async def test_incomplete_response(
|
|||||||
mock_config_entry_with_assist: MockConfigEntry,
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
mock_create_stream: AsyncMock,
|
mock_create_stream: AsyncMock,
|
||||||
mock_chat_log: MockChatLog, # noqa: F811
|
|
||||||
reason: str,
|
reason: str,
|
||||||
message: str,
|
message: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -301,7 +313,6 @@ async def test_failed_response(
|
|||||||
mock_config_entry_with_assist: MockConfigEntry,
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
mock_create_stream: AsyncMock,
|
mock_create_stream: AsyncMock,
|
||||||
mock_chat_log: MockChatLog, # noqa: F811
|
|
||||||
error: ResponseError | ResponseErrorEvent,
|
error: ResponseError | ResponseErrorEvent,
|
||||||
message: str,
|
message: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -491,6 +502,41 @@ def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEven
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
|
||||||
|
"""Create a web search call item."""
|
||||||
|
return [
|
||||||
|
ResponseOutputItemAddedEvent(
|
||||||
|
item=ResponseFunctionWebSearch(
|
||||||
|
id=id, status="in_progress", type="web_search_call"
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
type="response.output_item.added",
|
||||||
|
),
|
||||||
|
ResponseWebSearchCallInProgressEvent(
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
type="response.web_search_call.in_progress",
|
||||||
|
),
|
||||||
|
ResponseWebSearchCallSearchingEvent(
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
type="response.web_search_call.searching",
|
||||||
|
),
|
||||||
|
ResponseWebSearchCallCompletedEvent(
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
type="response.web_search_call.completed",
|
||||||
|
),
|
||||||
|
ResponseOutputItemDoneEvent(
|
||||||
|
item=ResponseFunctionWebSearch(
|
||||||
|
id=id, status="completed", type="web_search_call"
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
type="response.output_item.done",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
async def test_function_call(
|
async def test_function_call(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry_with_assist: MockConfigEntry,
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
@ -581,7 +627,6 @@ async def test_function_call_invalid(
|
|||||||
mock_config_entry_with_assist: MockConfigEntry,
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
mock_create_stream: AsyncMock,
|
mock_create_stream: AsyncMock,
|
||||||
mock_chat_log: MockChatLog, # noqa: F811
|
|
||||||
description: str,
|
description: str,
|
||||||
messages: tuple[ResponseStreamEvent],
|
messages: tuple[ResponseStreamEvent],
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -633,3 +678,60 @@ async def test_assist_api_tools_conversion(
|
|||||||
|
|
||||||
tools = mock_create_stream.mock_calls[0][2]["tools"]
|
tools = mock_create_stream.mock_calls[0][2]["tools"]
|
||||||
assert tools
|
assert tools
|
||||||
|
|
||||||
|
|
||||||
|
async def test_web_search(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
mock_create_stream,
|
||||||
|
mock_chat_log: MockChatLog, # noqa: F811
|
||||||
|
) -> None:
|
||||||
|
"""Test web_search_tool."""
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
mock_config_entry,
|
||||||
|
options={
|
||||||
|
**mock_config_entry.options,
|
||||||
|
CONF_WEB_SEARCH: True,
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
|
||||||
|
CONF_WEB_SEARCH_USER_LOCATION: True,
|
||||||
|
CONF_WEB_SEARCH_CITY: "San Francisco",
|
||||||
|
CONF_WEB_SEARCH_COUNTRY: "US",
|
||||||
|
CONF_WEB_SEARCH_REGION: "California",
|
||||||
|
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.config_entries.async_reload(mock_config_entry.entry_id)
|
||||||
|
|
||||||
|
message = "Home Assistant now supports ChatGPT Search in Assist"
|
||||||
|
mock_create_stream.return_value = [
|
||||||
|
# Initial conversation
|
||||||
|
(
|
||||||
|
*create_web_search_item(id="ws_A", output_index=0),
|
||||||
|
*create_message_item(id="msg_A", text=message, output_index=1),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"What's on the latest news?",
|
||||||
|
mock_chat_log.conversation_id,
|
||||||
|
Context(),
|
||||||
|
agent_id="conversation.openai",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_create_stream.mock_calls[0][2]["tools"] == [
|
||||||
|
{
|
||||||
|
"type": "web_search_preview",
|
||||||
|
"search_context_size": "low",
|
||||||
|
"user_location": {
|
||||||
|
"type": "approximate",
|
||||||
|
"city": "San Francisco",
|
||||||
|
"region": "California",
|
||||||
|
"country": "US",
|
||||||
|
"timezone": "America/Los_Angeles",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
assert result.response.speech["plain"]["speech"] == message, result.response.speech
|
||||||
|
Loading…
x
Reference in New Issue
Block a user