Add Web search to OpenAI Conversation integration (#141426)

* Add Web search to OpenAI Conversation integration

* Limit search for gpt-4o models

* Add more tests
This commit is contained in:
Denis Shulyaka 2025-03-26 16:36:05 +03:00 committed by GitHub
parent 8db91623ec
commit c974285490
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 397 additions and 14 deletions

View File

@ -2,22 +2,31 @@
from __future__ import annotations
import json
import logging
from types import MappingProxyType
from typing import Any
import openai
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components.zone import ENTITY_ID_HOME
from homeassistant.config_entries import (
ConfigEntry,
ConfigFlow,
ConfigFlowResult,
OptionsFlow,
)
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API
from homeassistant.const import (
ATTR_LATITUDE,
ATTR_LONGITUDE,
CONF_API_KEY,
CONF_LLM_HASS_API,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.helpers.httpx_client import get_async_client
from homeassistant.helpers.selector import (
NumberSelector,
NumberSelectorConfig,
@ -37,12 +46,22 @@ from .const import (
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_CITY,
CONF_WEB_SEARCH_CONTEXT_SIZE,
CONF_WEB_SEARCH_COUNTRY,
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
RECOMMENDED_WEB_SEARCH,
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
RECOMMENDED_WEB_SEARCH_USER_LOCATION,
UNSUPPORTED_MODELS,
)
@ -66,7 +85,9 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
"""
client = openai.AsyncOpenAI(api_key=data[CONF_API_KEY])
client = openai.AsyncOpenAI(
api_key=data[CONF_API_KEY], http_client=get_async_client(hass)
)
await hass.async_add_executor_job(client.with_options(timeout=10.0).models.list)
@ -137,7 +158,16 @@ class OpenAIOptionsFlow(OptionsFlow):
if user_input.get(CONF_CHAT_MODEL) in UNSUPPORTED_MODELS:
errors[CONF_CHAT_MODEL] = "model_not_supported"
else:
if user_input.get(CONF_WEB_SEARCH):
if not user_input.get(
CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL
).startswith("gpt-4o"):
errors[CONF_WEB_SEARCH] = "web_search_not_supported"
elif user_input.get(CONF_WEB_SEARCH_USER_LOCATION):
user_input.update(await self.get_location_data())
if not errors:
return self.async_create_entry(title="", data=user_input)
else:
# Re-render the options again, now with the recommended options shown/hidden
@ -156,6 +186,59 @@ class OpenAIOptionsFlow(OptionsFlow):
errors=errors,
)
async def get_location_data(self) -> dict[str, str]:
    """Collect approximate location details used to refine web searches.

    When the home zone state exists, its coordinates are sent to the
    recommended chat model, which is asked to answer with a JSON object
    holding the optional city and region. The country (when configured)
    and the time zone are then taken from the Home Assistant configuration.
    """
    home_zone = self.hass.states.get(ENTITY_ID_HOME)
    result: dict[str, str] = {}

    if home_zone is not None:
        api_client = openai.AsyncOpenAI(
            api_key=self.config_entry.data[CONF_API_KEY],
            http_client=get_async_client(self.hass),
        )
        # Both fields are optional, so a partial (or empty) answer validates.
        city_region_schema = vol.Schema(
            {
                vol.Optional(
                    CONF_WEB_SEARCH_CITY,
                    description="Free text input for the city, e.g. `San Francisco`",
                ): str,
                vol.Optional(
                    CONF_WEB_SEARCH_REGION,
                    description="Free text input for the region, e.g. `California`",
                ): str,
            }
        )
        latitude = home_zone.attributes[ATTR_LATITUDE]
        longitude = home_zone.attributes[ATTR_LONGITUDE]
        question = {
            "role": "system",
            "content": (
                "Where are the following coordinates located: "
                f"({latitude}, {longitude})?"
            ),
        }
        output_format = {
            "type": "json_schema",
            "name": "approximate_location",
            "description": (
                "Approximate location data of the user "
                "for refined web search results"
            ),
            "schema": convert(city_region_schema),
            "strict": False,
        }
        reply = await api_client.responses.create(
            model=RECOMMENDED_CHAT_MODEL,
            input=[question],
            text={"format": output_format},
            store=False,
        )
        # A falsy parsed payload (e.g. JSON `null`) is replaced by an empty
        # dict before it is validated against the schema.
        result = city_region_schema(json.loads(reply.output_text) or {})

    country = self.hass.config.country
    if country:
        result[CONF_WEB_SEARCH_COUNTRY] = country
    result[CONF_WEB_SEARCH_TIMEZONE] = self.hass.config.time_zone

    _LOGGER.debug("Location data: %s", result)
    return result
def openai_config_option_schema(
hass: HomeAssistant,
@ -227,10 +310,35 @@ def openai_config_option_schema(
): SelectSelector(
SelectSelectorConfig(
options=["low", "medium", "high"],
translation_key="reasoning_effort",
translation_key=CONF_REASONING_EFFORT,
mode=SelectSelectorMode.DROPDOWN,
)
),
vol.Optional(
CONF_WEB_SEARCH,
description={"suggested_value": options.get(CONF_WEB_SEARCH)},
default=RECOMMENDED_WEB_SEARCH,
): bool,
vol.Optional(
CONF_WEB_SEARCH_CONTEXT_SIZE,
description={
"suggested_value": options.get(CONF_WEB_SEARCH_CONTEXT_SIZE)
},
default=RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
): SelectSelector(
SelectSelectorConfig(
options=["low", "medium", "high"],
translation_key=CONF_WEB_SEARCH_CONTEXT_SIZE,
mode=SelectSelectorMode.DROPDOWN,
)
),
vol.Optional(
CONF_WEB_SEARCH_USER_LOCATION,
description={
"suggested_value": options.get(CONF_WEB_SEARCH_USER_LOCATION)
},
default=RECOMMENDED_WEB_SEARCH_USER_LOCATION,
): bool,
}
)
return schema

View File

@ -14,11 +14,21 @@ CONF_REASONING_EFFORT = "reasoning_effort"
CONF_RECOMMENDED = "recommended"
CONF_TEMPERATURE = "temperature"
CONF_TOP_P = "top_p"
CONF_WEB_SEARCH = "web_search"
CONF_WEB_SEARCH_USER_LOCATION = "user_location"
CONF_WEB_SEARCH_CONTEXT_SIZE = "search_context_size"
CONF_WEB_SEARCH_CITY = "city"
CONF_WEB_SEARCH_REGION = "region"
CONF_WEB_SEARCH_COUNTRY = "country"
CONF_WEB_SEARCH_TIMEZONE = "timezone"
RECOMMENDED_CHAT_MODEL = "gpt-4o-mini"
RECOMMENDED_MAX_TOKENS = 150
RECOMMENDED_REASONING_EFFORT = "low"
RECOMMENDED_TEMPERATURE = 1.0
RECOMMENDED_TOP_P = 1.0
RECOMMENDED_WEB_SEARCH = False
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE = "medium"
RECOMMENDED_WEB_SEARCH_USER_LOCATION = False
UNSUPPORTED_MODELS: list[str] = [
"o1-mini",

View File

@ -23,8 +23,10 @@ from openai.types.responses import (
ResponseStreamEvent,
ResponseTextDeltaEvent,
ToolParam,
WebSearchToolParam,
)
from openai.types.responses.response_input_param import FunctionCallOutput
from openai.types.responses.web_search_tool_param import UserLocation
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
@ -43,6 +45,13 @@ from .const import (
CONF_REASONING_EFFORT,
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_CITY,
CONF_WEB_SEARCH_CONTEXT_SIZE,
CONF_WEB_SEARCH_COUNTRY,
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DOMAIN,
LOGGER,
RECOMMENDED_CHAT_MODEL,
@ -50,6 +59,7 @@ from .const import (
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
)
# Max number of back and forth with the LLM to generate a response
@ -265,6 +275,25 @@ class OpenAIConversationEntity(
for tool in chat_log.llm_api.tools
]
if options.get(CONF_WEB_SEARCH):
web_search = WebSearchToolParam(
type="web_search_preview",
search_context_size=options.get(
CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
),
)
if options.get(CONF_WEB_SEARCH_USER_LOCATION):
web_search["user_location"] = UserLocation(
type="approximate",
city=options.get(CONF_WEB_SEARCH_CITY, ""),
region=options.get(CONF_WEB_SEARCH_REGION, ""),
country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
)
if tools is None:
tools = []
tools.append(web_search)
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
messages = [
m

View File

@ -24,16 +24,23 @@
"top_p": "Top P",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"recommended": "Recommended model settings",
"reasoning_effort": "Reasoning effort"
"reasoning_effort": "Reasoning effort",
"web_search": "Enable web search",
"search_context_size": "Search context size",
"user_location": "Include home location"
},
"data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template.",
"reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt (for certain reasoning models)"
"reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt (for certain reasoning models)",
"web_search": "Allow the model to search the web for the latest information before generating a response",
"search_context_size": "High level guidance for the amount of context window space to use for the search",
"user_location": "Refine search results based on geography"
}
}
},
"error": {
"model_not_supported": "This model is not supported, please select a different model"
"model_not_supported": "This model is not supported, please select a different model",
"web_search_not_supported": "Web search is only supported for gpt-4o and gpt-4o-mini models"
}
},
"selector": {
@ -43,6 +50,13 @@
"medium": "Medium",
"high": "High"
}
},
"search_context_size": {
"options": {
"low": "Low",
"medium": "Medium",
"high": "High"
}
}
},
"services": {

View File

@ -1,9 +1,10 @@
"""Test the OpenAI Conversation config flow."""
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
from httpx import Response
import httpx
from openai import APIConnectionError, AuthenticationError, BadRequestError
from openai.types.responses import Response, ResponseOutputMessage, ResponseOutputText
import pytest
from homeassistant import config_entries
@ -16,6 +17,13 @@ from homeassistant.components.openai_conversation.const import (
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_CITY,
CONF_WEB_SEARCH_CONTEXT_SIZE,
CONF_WEB_SEARCH_COUNTRY,
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
@ -117,13 +125,17 @@ async def test_options_unsupported_model(
(APIConnectionError(request=None), "cannot_connect"),
(
AuthenticationError(
response=Response(status_code=None, request=""), body=None, message=None
response=httpx.Response(status_code=None, request=""),
body=None,
message=None,
),
"invalid_auth",
),
(
BadRequestError(
response=Response(status_code=None, request=""), body=None, message=None
response=httpx.Response(status_code=None, request=""),
body=None,
message=None,
),
"unknown",
),
@ -172,6 +184,9 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
CONF_WEB_SEARCH: False,
CONF_WEB_SEARCH_CONTEXT_SIZE: "medium",
CONF_WEB_SEARCH_USER_LOCATION: False,
},
),
(
@ -183,6 +198,9 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
CONF_WEB_SEARCH: False,
CONF_WEB_SEARCH_CONTEXT_SIZE: "medium",
CONF_WEB_SEARCH_USER_LOCATION: False,
},
{
CONF_RECOMMENDED: True,
@ -225,3 +243,105 @@ async def test_options_switching(
await hass.async_block_till_done()
assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"] == expected_options
async def test_options_web_search_user_location(
    hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
    """Check that enabling user location stores the looked-up location data."""
    flow = await hass.config_entries.options.async_init(mock_config_entry.entry_id)

    hass.config.country = "US"
    hass.config.time_zone = "America/Los_Angeles"
    hass.states.async_set(
        "zone.home", "0", {"latitude": 37.7749, "longitude": -122.4194}
    )

    # Options submitted through the form; the flow is expected to keep them
    # and add the location keys on top.
    submitted = {
        CONF_RECOMMENDED: False,
        CONF_PROMPT: "Speak like a pirate",
        CONF_TEMPERATURE: 1.0,
        CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
        CONF_TOP_P: RECOMMENDED_TOP_P,
        CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
        CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
        CONF_WEB_SEARCH: True,
        CONF_WEB_SEARCH_CONTEXT_SIZE: "medium",
        CONF_WEB_SEARCH_USER_LOCATION: True,
    }

    # Canned model answer carrying the city/region JSON payload.
    answer_text = ResponseOutputText(
        type="output_text",
        text='{"city": "San Francisco", "region": "California"}',
        annotations=[],
    )
    answer_message = ResponseOutputMessage(
        type="message",
        id="msg_A",
        content=[answer_text],
        role="assistant",
        status="completed",
    )
    canned_response = Response(
        object="response",
        id="resp_A",
        created_at=1700000000,
        model="gpt-4o-mini",
        parallel_tool_calls=True,
        tool_choice="auto",
        tools=[],
        output=[answer_message],
    )

    with patch(
        "openai.resources.responses.AsyncResponses.create",
        new_callable=AsyncMock,
    ) as mock_create:
        mock_create.return_value = canned_response

        options = await hass.config_entries.options.async_configure(
            flow["flow_id"], dict(submitted)
        )
        await hass.async_block_till_done()

    sent_prompt = mock_create.call_args.kwargs["input"][0]["content"]
    assert (
        sent_prompt
        == "Where are the following coordinates located: (37.7749, -122.4194)?"
    )
    assert options["type"] is FlowResultType.CREATE_ENTRY
    assert options["data"] == {
        **submitted,
        CONF_WEB_SEARCH_CITY: "San Francisco",
        CONF_WEB_SEARCH_REGION: "California",
        CONF_WEB_SEARCH_COUNTRY: "US",
        CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
    }
async def test_options_web_search_unsupported_model(
    hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
    """Check that enabling web search with "o1-pro" re-shows the form with an error."""
    flow = await hass.config_entries.options.async_init(mock_config_entry.entry_id)

    # Web search is switched on together with the "o1-pro" chat model.
    submitted = {
        CONF_RECOMMENDED: False,
        CONF_PROMPT: "Speak like a pirate",
        CONF_CHAT_MODEL: "o1-pro",
        CONF_LLM_HASS_API: "assist",
        CONF_WEB_SEARCH: True,
    }
    result = await hass.config_entries.options.async_configure(
        flow["flow_id"], submitted
    )
    await hass.async_block_till_done()

    assert result["type"] is FlowResultType.FORM
    assert result["errors"] == {"web_search": "web_search_not_supported"}

View File

@ -18,6 +18,7 @@ from openai.types.responses import (
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
ResponseIncompleteEvent,
ResponseInProgressEvent,
ResponseOutputItemAddedEvent,
@ -29,6 +30,9 @@ from openai.types.responses import (
ResponseTextConfig,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
)
from openai.types.responses.response import IncompleteDetails
import pytest
@ -36,6 +40,15 @@ from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.components.openai_conversation.const import (
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_CITY,
CONF_WEB_SEARCH_CONTEXT_SIZE,
CONF_WEB_SEARCH_COUNTRY,
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent
@ -225,7 +238,6 @@ async def test_incomplete_response(
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
mock_create_stream: AsyncMock,
mock_chat_log: MockChatLog, # noqa: F811
reason: str,
message: str,
) -> None:
@ -301,7 +313,6 @@ async def test_failed_response(
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
mock_create_stream: AsyncMock,
mock_chat_log: MockChatLog, # noqa: F811
error: ResponseError | ResponseErrorEvent,
message: str,
) -> None:
@ -491,6 +502,41 @@ def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEven
]
def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
    """Return the stream events of one web search call, from added to done.

    The sequence is: output item added (in progress), web search call
    in-progress, searching and completed events, then output item done
    (completed). Every event carries the given ``output_index``.
    """

    def _search_call(status: str) -> ResponseFunctionWebSearch:
        # Web search call item with the shared id in the requested status.
        return ResponseFunctionWebSearch(id=id, status=status, type="web_search_call")

    item_added = ResponseOutputItemAddedEvent(
        item=_search_call("in_progress"),
        output_index=output_index,
        type="response.output_item.added",
    )
    call_in_progress = ResponseWebSearchCallInProgressEvent(
        item_id=id,
        output_index=output_index,
        type="response.web_search_call.in_progress",
    )
    call_searching = ResponseWebSearchCallSearchingEvent(
        item_id=id,
        output_index=output_index,
        type="response.web_search_call.searching",
    )
    call_completed = ResponseWebSearchCallCompletedEvent(
        item_id=id,
        output_index=output_index,
        type="response.web_search_call.completed",
    )
    item_done = ResponseOutputItemDoneEvent(
        item=_search_call("completed"),
        output_index=output_index,
        type="response.output_item.done",
    )
    return [item_added, call_in_progress, call_searching, call_completed, item_done]
async def test_function_call(
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
@ -581,7 +627,6 @@ async def test_function_call_invalid(
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
mock_create_stream: AsyncMock,
mock_chat_log: MockChatLog, # noqa: F811
description: str,
messages: tuple[ResponseStreamEvent],
) -> None:
@ -633,3 +678,60 @@ async def test_assist_api_tools_conversion(
tools = mock_create_stream.mock_calls[0][2]["tools"]
assert tools
async def test_web_search(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    mock_create_stream,
    mock_chat_log: MockChatLog,  # noqa: F811
) -> None:
    """Check that the configured web search tool is passed to the model call."""
    web_search_options = {
        CONF_WEB_SEARCH: True,
        CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
        CONF_WEB_SEARCH_USER_LOCATION: True,
        CONF_WEB_SEARCH_CITY: "San Francisco",
        CONF_WEB_SEARCH_COUNTRY: "US",
        CONF_WEB_SEARCH_REGION: "California",
        CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
    }
    hass.config_entries.async_update_entry(
        mock_config_entry,
        options={**mock_config_entry.options, **web_search_options},
    )
    await hass.config_entries.async_reload(mock_config_entry.entry_id)

    reply_text = "Home Assistant now supports ChatGPT Search in Assist"
    # Single model turn: a web search call followed by the spoken answer.
    first_turn = (
        *create_web_search_item(id="ws_A", output_index=0),
        *create_message_item(id="msg_A", text=reply_text, output_index=1),
    )
    mock_create_stream.return_value = [first_turn]

    result = await conversation.async_converse(
        hass,
        "What's on the latest news?",
        mock_chat_log.conversation_id,
        Context(),
        agent_id="conversation.openai",
    )

    expected_tool = {
        "type": "web_search_preview",
        "search_context_size": "low",
        "user_location": {
            "type": "approximate",
            "city": "San Francisco",
            "region": "California",
            "country": "US",
            "timezone": "America/Los_Angeles",
        },
    }
    sent_tools = mock_create_stream.mock_calls[0][2]["tools"]
    assert sent_tools == [expected_tool]

    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    assert result.response.speech["plain"]["speech"] == reply_text, (
        result.response.speech
    )