Mirror of https://github.com/home-assistant/core.git
Add Web search to OpenAI Conversation integration (#141426)
* Add Web search to OpenAI Conversation integration
* Limit search for gpt-4o models
* Add more tests
Parent: 8db91623ec
Commit: c974285490
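For orientation before the diff: when web search is enabled in the options flow, the config entry ends up with options along these lines. This is a sketch only; the key names are the new CONF_* constants added in const.py below, and the values simply mirror the test fixtures further down.

# Sketch: entry options stored once web search is enabled in the options flow.
# Key names come from const.py in this diff; values mirror the tests below.
options = {
    "web_search": True,               # CONF_WEB_SEARCH
    "search_context_size": "medium",  # CONF_WEB_SEARCH_CONTEXT_SIZE: "low" | "medium" | "high"
    "user_location": True,            # CONF_WEB_SEARCH_USER_LOCATION
    # city/region come from an OpenAI lookup of the Home zone coordinates;
    # country/timezone come from the Home Assistant core config.
    "city": "San Francisco",
    "region": "California",
    "country": "US",
    "timezone": "America/Los_Angeles",
}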
@@ -2,22 +2,31 @@

 from __future__ import annotations

+import json
 import logging
 from types import MappingProxyType
 from typing import Any

 import openai
 import voluptuous as vol
+from voluptuous_openapi import convert

+from homeassistant.components.zone import ENTITY_ID_HOME
 from homeassistant.config_entries import (
     ConfigEntry,
     ConfigFlow,
     ConfigFlowResult,
     OptionsFlow,
 )
-from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API
+from homeassistant.const import (
+    ATTR_LATITUDE,
+    ATTR_LONGITUDE,
+    CONF_API_KEY,
+    CONF_LLM_HASS_API,
+)
 from homeassistant.core import HomeAssistant
 from homeassistant.helpers import llm
+from homeassistant.helpers.httpx_client import get_async_client
 from homeassistant.helpers.selector import (
     NumberSelector,
     NumberSelectorConfig,
@@ -37,12 +46,22 @@ from .const import (
     CONF_RECOMMENDED,
     CONF_TEMPERATURE,
     CONF_TOP_P,
+    CONF_WEB_SEARCH,
+    CONF_WEB_SEARCH_CITY,
+    CONF_WEB_SEARCH_CONTEXT_SIZE,
+    CONF_WEB_SEARCH_COUNTRY,
+    CONF_WEB_SEARCH_REGION,
+    CONF_WEB_SEARCH_TIMEZONE,
+    CONF_WEB_SEARCH_USER_LOCATION,
     DOMAIN,
     RECOMMENDED_CHAT_MODEL,
     RECOMMENDED_MAX_TOKENS,
     RECOMMENDED_REASONING_EFFORT,
     RECOMMENDED_TEMPERATURE,
     RECOMMENDED_TOP_P,
+    RECOMMENDED_WEB_SEARCH,
+    RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
+    RECOMMENDED_WEB_SEARCH_USER_LOCATION,
     UNSUPPORTED_MODELS,
 )
@@ -66,7 +85,9 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:

     Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
     """
-    client = openai.AsyncOpenAI(api_key=data[CONF_API_KEY])
+    client = openai.AsyncOpenAI(
+        api_key=data[CONF_API_KEY], http_client=get_async_client(hass)
+    )
     await hass.async_add_executor_job(client.with_options(timeout=10.0).models.list)


@@ -137,7 +158,16 @@ class OpenAIOptionsFlow(OptionsFlow):

                 if user_input.get(CONF_CHAT_MODEL) in UNSUPPORTED_MODELS:
                     errors[CONF_CHAT_MODEL] = "model_not_supported"
-                else:
+
+                if user_input.get(CONF_WEB_SEARCH):
+                    if not user_input.get(
+                        CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL
+                    ).startswith("gpt-4o"):
+                        errors[CONF_WEB_SEARCH] = "web_search_not_supported"
+                    elif user_input.get(CONF_WEB_SEARCH_USER_LOCATION):
+                        user_input.update(await self.get_location_data())
+
+                if not errors:
                     return self.async_create_entry(title="", data=user_input)
             else:
                 # Re-render the options again, now with the recommended options shown/hidden
@@ -156,6 +186,59 @@ class OpenAIOptionsFlow(OptionsFlow):
             errors=errors,
         )

+    async def get_location_data(self) -> dict[str, str]:
+        """Get approximate location data of the user."""
+        location_data: dict[str, str] = {}
+        zone_home = self.hass.states.get(ENTITY_ID_HOME)
+        if zone_home is not None:
+            client = openai.AsyncOpenAI(
+                api_key=self.config_entry.data[CONF_API_KEY],
+                http_client=get_async_client(self.hass),
+            )
+            location_schema = vol.Schema(
+                {
+                    vol.Optional(
+                        CONF_WEB_SEARCH_CITY,
+                        description="Free text input for the city, e.g. `San Francisco`",
+                    ): str,
+                    vol.Optional(
+                        CONF_WEB_SEARCH_REGION,
+                        description="Free text input for the region, e.g. `California`",
+                    ): str,
+                }
+            )
+            response = await client.responses.create(
+                model=RECOMMENDED_CHAT_MODEL,
+                input=[
+                    {
+                        "role": "system",
+                        "content": "Where are the following coordinates located: "
+                        f"({zone_home.attributes[ATTR_LATITUDE]},"
+                        f" {zone_home.attributes[ATTR_LONGITUDE]})?",
+                    }
+                ],
+                text={
+                    "format": {
+                        "type": "json_schema",
+                        "name": "approximate_location",
+                        "description": "Approximate location data of the user "
+                        "for refined web search results",
+                        "schema": convert(location_schema),
+                        "strict": False,
+                    }
+                },
+                store=False,
+            )
+            location_data = location_schema(json.loads(response.output_text) or {})
+
+        if self.hass.config.country:
+            location_data[CONF_WEB_SEARCH_COUNTRY] = self.hass.config.country
+        location_data[CONF_WEB_SEARCH_TIMEZONE] = self.hass.config.time_zone
+
+        _LOGGER.debug("Location data: %s", location_data)
+
+        return location_data
+

 def openai_config_option_schema(
     hass: HomeAssistant,
@@ -227,10 +310,35 @@ def openai_config_option_schema(
             ): SelectSelector(
                 SelectSelectorConfig(
                     options=["low", "medium", "high"],
-                    translation_key="reasoning_effort",
+                    translation_key=CONF_REASONING_EFFORT,
                     mode=SelectSelectorMode.DROPDOWN,
                 )
             ),
+            vol.Optional(
+                CONF_WEB_SEARCH,
+                description={"suggested_value": options.get(CONF_WEB_SEARCH)},
+                default=RECOMMENDED_WEB_SEARCH,
+            ): bool,
+            vol.Optional(
+                CONF_WEB_SEARCH_CONTEXT_SIZE,
+                description={
+                    "suggested_value": options.get(CONF_WEB_SEARCH_CONTEXT_SIZE)
+                },
+                default=RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
+            ): SelectSelector(
+                SelectSelectorConfig(
+                    options=["low", "medium", "high"],
+                    translation_key=CONF_WEB_SEARCH_CONTEXT_SIZE,
+                    mode=SelectSelectorMode.DROPDOWN,
+                )
+            ),
+            vol.Optional(
+                CONF_WEB_SEARCH_USER_LOCATION,
+                description={
+                    "suggested_value": options.get(CONF_WEB_SEARCH_USER_LOCATION)
+                },
+                default=RECOMMENDED_WEB_SEARCH_USER_LOCATION,
+            ): bool,
         }
     )
     return schema

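The get_location_data helper above asks the model for structured output: it converts a small voluptuous schema to JSON Schema with voluptuous_openapi.convert, passes it as the response text format, and validates the model's JSON reply against the same schema. A standalone sketch of that round trip (the literal JSON reply is taken from the test fixture further down):

# Sketch only: the schema-convert-validate round trip used by get_location_data.
import json

import voluptuous as vol
from voluptuous_openapi import convert

location_schema = vol.Schema(
    {
        vol.Optional("city"): str,    # CONF_WEB_SEARCH_CITY
        vol.Optional("region"): str,  # CONF_WEB_SEARCH_REGION
    }
)

# convert() yields the JSON Schema dict passed as the "schema" of the text format above.
json_schema = convert(location_schema)

# The model replies with JSON matching that schema; it is validated back through
# the same voluptuous schema before being merged into the entry options.
reply = '{"city": "San Francisco", "region": "California"}'
location_data = location_schema(json.loads(reply) or {})
assert location_data == {"city": "San Francisco", "region": "California"}
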
@@ -14,11 +14,21 @@ CONF_REASONING_EFFORT = "reasoning_effort"
 CONF_RECOMMENDED = "recommended"
 CONF_TEMPERATURE = "temperature"
 CONF_TOP_P = "top_p"
+CONF_WEB_SEARCH = "web_search"
+CONF_WEB_SEARCH_USER_LOCATION = "user_location"
+CONF_WEB_SEARCH_CONTEXT_SIZE = "search_context_size"
+CONF_WEB_SEARCH_CITY = "city"
+CONF_WEB_SEARCH_REGION = "region"
+CONF_WEB_SEARCH_COUNTRY = "country"
+CONF_WEB_SEARCH_TIMEZONE = "timezone"
 RECOMMENDED_CHAT_MODEL = "gpt-4o-mini"
 RECOMMENDED_MAX_TOKENS = 150
 RECOMMENDED_REASONING_EFFORT = "low"
 RECOMMENDED_TEMPERATURE = 1.0
 RECOMMENDED_TOP_P = 1.0
+RECOMMENDED_WEB_SEARCH = False
+RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE = "medium"
+RECOMMENDED_WEB_SEARCH_USER_LOCATION = False

 UNSUPPORTED_MODELS: list[str] = [
     "o1-mini",

@@ -23,8 +23,10 @@ from openai.types.responses import (
     ResponseStreamEvent,
     ResponseTextDeltaEvent,
     ToolParam,
+    WebSearchToolParam,
 )
 from openai.types.responses.response_input_param import FunctionCallOutput
+from openai.types.responses.web_search_tool_param import UserLocation
 from voluptuous_openapi import convert

 from homeassistant.components import assist_pipeline, conversation
@@ -43,6 +45,13 @@ from .const import (
     CONF_REASONING_EFFORT,
     CONF_TEMPERATURE,
     CONF_TOP_P,
+    CONF_WEB_SEARCH,
+    CONF_WEB_SEARCH_CITY,
+    CONF_WEB_SEARCH_CONTEXT_SIZE,
+    CONF_WEB_SEARCH_COUNTRY,
+    CONF_WEB_SEARCH_REGION,
+    CONF_WEB_SEARCH_TIMEZONE,
+    CONF_WEB_SEARCH_USER_LOCATION,
     DOMAIN,
     LOGGER,
     RECOMMENDED_CHAT_MODEL,
@@ -50,6 +59,7 @@ from .const import (
     RECOMMENDED_REASONING_EFFORT,
     RECOMMENDED_TEMPERATURE,
     RECOMMENDED_TOP_P,
+    RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
 )

 # Max number of back and forth with the LLM to generate a response
@@ -265,6 +275,25 @@ class OpenAIConversationEntity(
                 for tool in chat_log.llm_api.tools
             ]

+        if options.get(CONF_WEB_SEARCH):
+            web_search = WebSearchToolParam(
+                type="web_search_preview",
+                search_context_size=options.get(
+                    CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
+                ),
+            )
+            if options.get(CONF_WEB_SEARCH_USER_LOCATION):
+                web_search["user_location"] = UserLocation(
+                    type="approximate",
+                    city=options.get(CONF_WEB_SEARCH_CITY, ""),
+                    region=options.get(CONF_WEB_SEARCH_REGION, ""),
+                    country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
+                    timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
+                )
+            if tools is None:
+                tools = []
+            tools.append(web_search)
+
         model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
         messages = [
             m

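At request time the hunk above turns those options into a web_search_preview tool entry that is appended to the tools list for client.responses.create. A minimal standalone sketch of that dict, mirroring the shape asserted in test_web_search below (the values are the test fixtures, not defaults):

# Sketch: the tool parameter appended to `tools` when web search is enabled.
from openai.types.responses import WebSearchToolParam
from openai.types.responses.web_search_tool_param import UserLocation

web_search = WebSearchToolParam(
    type="web_search_preview",
    search_context_size="low",  # "low" | "medium" | "high"
)
# Only added when the user opted in to sharing an approximate home location:
web_search["user_location"] = UserLocation(
    type="approximate",
    city="San Francisco",
    region="California",
    country="US",
    timezone="America/Los_Angeles",
)
# web_search is then appended to the tools list passed to client.responses.create().
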
@@ -24,16 +24,23 @@
           "top_p": "Top P",
           "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
           "recommended": "Recommended model settings",
-          "reasoning_effort": "Reasoning effort"
+          "reasoning_effort": "Reasoning effort",
+          "web_search": "Enable web search",
+          "search_context_size": "Search context size",
+          "user_location": "Include home location"
         },
         "data_description": {
           "prompt": "Instruct how the LLM should respond. This can be a template.",
-          "reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt (for certain reasoning models)"
+          "reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt (for certain reasoning models)",
+          "web_search": "Allow the model to search the web for the latest information before generating a response",
+          "search_context_size": "High level guidance for the amount of context window space to use for the search",
+          "user_location": "Refine search results based on geography"
         }
       }
     },
     "error": {
-      "model_not_supported": "This model is not supported, please select a different model"
+      "model_not_supported": "This model is not supported, please select a different model",
+      "web_search_not_supported": "Web search is only supported for gpt-4o and gpt-4o-mini models"
     }
   },
   "selector": {
@@ -43,6 +50,13 @@
         "medium": "Medium",
         "high": "High"
       }
+    },
+    "search_context_size": {
+      "options": {
+        "low": "Low",
+        "medium": "Medium",
+        "high": "High"
+      }
     }
   },
   "services": {

@@ -1,9 +1,10 @@
 """Test the OpenAI Conversation config flow."""

-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch

-from httpx import Response
+import httpx
 from openai import APIConnectionError, AuthenticationError, BadRequestError
+from openai.types.responses import Response, ResponseOutputMessage, ResponseOutputText
 import pytest

 from homeassistant import config_entries
@@ -16,6 +17,13 @@ from homeassistant.components.openai_conversation.const import (
     CONF_RECOMMENDED,
     CONF_TEMPERATURE,
     CONF_TOP_P,
+    CONF_WEB_SEARCH,
+    CONF_WEB_SEARCH_CITY,
+    CONF_WEB_SEARCH_CONTEXT_SIZE,
+    CONF_WEB_SEARCH_COUNTRY,
+    CONF_WEB_SEARCH_REGION,
+    CONF_WEB_SEARCH_TIMEZONE,
+    CONF_WEB_SEARCH_USER_LOCATION,
     DOMAIN,
     RECOMMENDED_CHAT_MODEL,
     RECOMMENDED_MAX_TOKENS,
@@ -117,13 +125,17 @@ async def test_options_unsupported_model(
         (APIConnectionError(request=None), "cannot_connect"),
         (
             AuthenticationError(
-                response=Response(status_code=None, request=""), body=None, message=None
+                response=httpx.Response(status_code=None, request=""),
+                body=None,
+                message=None,
             ),
             "invalid_auth",
         ),
         (
             BadRequestError(
-                response=Response(status_code=None, request=""), body=None, message=None
+                response=httpx.Response(status_code=None, request=""),
+                body=None,
+                message=None,
             ),
             "unknown",
         ),
@@ -172,6 +184,9 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
             CONF_TOP_P: RECOMMENDED_TOP_P,
             CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
             CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
+            CONF_WEB_SEARCH: False,
+            CONF_WEB_SEARCH_CONTEXT_SIZE: "medium",
+            CONF_WEB_SEARCH_USER_LOCATION: False,
         },
     ),
     (
@@ -183,6 +198,9 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
             CONF_TOP_P: RECOMMENDED_TOP_P,
             CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
             CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
+            CONF_WEB_SEARCH: False,
+            CONF_WEB_SEARCH_CONTEXT_SIZE: "medium",
+            CONF_WEB_SEARCH_USER_LOCATION: False,
         },
         {
             CONF_RECOMMENDED: True,
@@ -225,3 +243,105 @@ async def test_options_switching(
     await hass.async_block_till_done()
     assert options["type"] is FlowResultType.CREATE_ENTRY
     assert options["data"] == expected_options
+
+
+async def test_options_web_search_user_location(
+    hass: HomeAssistant, mock_config_entry, mock_init_component
+) -> None:
+    """Test fetching user location."""
+    options_flow = await hass.config_entries.options.async_init(
+        mock_config_entry.entry_id
+    )
+    hass.config.country = "US"
+    hass.config.time_zone = "America/Los_Angeles"
+    hass.states.async_set(
+        "zone.home", "0", {"latitude": 37.7749, "longitude": -122.4194}
+    )
+    with patch(
+        "openai.resources.responses.AsyncResponses.create",
+        new_callable=AsyncMock,
+    ) as mock_create:
+        mock_create.return_value = Response(
+            object="response",
+            id="resp_A",
+            created_at=1700000000,
+            model="gpt-4o-mini",
+            parallel_tool_calls=True,
+            tool_choice="auto",
+            tools=[],
+            output=[
+                ResponseOutputMessage(
+                    type="message",
+                    id="msg_A",
+                    content=[
+                        ResponseOutputText(
+                            type="output_text",
+                            text='{"city": "San Francisco", "region": "California"}',
+                            annotations=[],
+                        )
+                    ],
+                    role="assistant",
+                    status="completed",
+                )
+            ],
+        )
+
+        options = await hass.config_entries.options.async_configure(
+            options_flow["flow_id"],
+            {
+                CONF_RECOMMENDED: False,
+                CONF_PROMPT: "Speak like a pirate",
+                CONF_TEMPERATURE: 1.0,
+                CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
+                CONF_TOP_P: RECOMMENDED_TOP_P,
+                CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+                CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
+                CONF_WEB_SEARCH: True,
+                CONF_WEB_SEARCH_CONTEXT_SIZE: "medium",
+                CONF_WEB_SEARCH_USER_LOCATION: True,
+            },
+        )
+        await hass.async_block_till_done()
+    assert (
+        mock_create.call_args.kwargs["input"][0]["content"] == "Where are the following"
+        " coordinates located: (37.7749, -122.4194)?"
+    )
+    assert options["type"] is FlowResultType.CREATE_ENTRY
+    assert options["data"] == {
+        CONF_RECOMMENDED: False,
+        CONF_PROMPT: "Speak like a pirate",
+        CONF_TEMPERATURE: 1.0,
+        CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
+        CONF_TOP_P: RECOMMENDED_TOP_P,
+        CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+        CONF_REASONING_EFFORT: RECOMMENDED_REASONING_EFFORT,
+        CONF_WEB_SEARCH: True,
+        CONF_WEB_SEARCH_CONTEXT_SIZE: "medium",
+        CONF_WEB_SEARCH_USER_LOCATION: True,
+        CONF_WEB_SEARCH_CITY: "San Francisco",
+        CONF_WEB_SEARCH_REGION: "California",
+        CONF_WEB_SEARCH_COUNTRY: "US",
+        CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
+    }
+
+
+async def test_options_web_search_unsupported_model(
+    hass: HomeAssistant, mock_config_entry, mock_init_component
+) -> None:
+    """Test the options form giving error about web search not being available."""
+    options_flow = await hass.config_entries.options.async_init(
+        mock_config_entry.entry_id
+    )
+    result = await hass.config_entries.options.async_configure(
+        options_flow["flow_id"],
+        {
+            CONF_RECOMMENDED: False,
+            CONF_PROMPT: "Speak like a pirate",
+            CONF_CHAT_MODEL: "o1-pro",
+            CONF_LLM_HASS_API: "assist",
+            CONF_WEB_SEARCH: True,
+        },
+    )
+    await hass.async_block_till_done()
+    assert result["type"] is FlowResultType.FORM
+    assert result["errors"] == {"web_search": "web_search_not_supported"}

@@ -18,6 +18,7 @@ from openai.types.responses import (
     ResponseFunctionCallArgumentsDeltaEvent,
     ResponseFunctionCallArgumentsDoneEvent,
     ResponseFunctionToolCall,
+    ResponseFunctionWebSearch,
     ResponseIncompleteEvent,
     ResponseInProgressEvent,
     ResponseOutputItemAddedEvent,
@@ -29,6 +30,9 @@ from openai.types.responses import (
     ResponseTextConfig,
     ResponseTextDeltaEvent,
     ResponseTextDoneEvent,
+    ResponseWebSearchCallCompletedEvent,
+    ResponseWebSearchCallInProgressEvent,
+    ResponseWebSearchCallSearchingEvent,
 )
 from openai.types.responses.response import IncompleteDetails
 import pytest
@@ -36,6 +40,15 @@ from syrupy.assertion import SnapshotAssertion

 from homeassistant.components import conversation
 from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
+from homeassistant.components.openai_conversation.const import (
+    CONF_WEB_SEARCH,
+    CONF_WEB_SEARCH_CITY,
+    CONF_WEB_SEARCH_CONTEXT_SIZE,
+    CONF_WEB_SEARCH_COUNTRY,
+    CONF_WEB_SEARCH_REGION,
+    CONF_WEB_SEARCH_TIMEZONE,
+    CONF_WEB_SEARCH_USER_LOCATION,
+)
 from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import Context, HomeAssistant
 from homeassistant.helpers import intent
@@ -225,7 +238,6 @@ async def test_incomplete_response(
     mock_config_entry_with_assist: MockConfigEntry,
     mock_init_component,
     mock_create_stream: AsyncMock,
     mock_chat_log: MockChatLog,  # noqa: F811
     reason: str,
     message: str,
 ) -> None:
@@ -301,7 +313,6 @@ async def test_failed_response(
     mock_config_entry_with_assist: MockConfigEntry,
     mock_init_component,
     mock_create_stream: AsyncMock,
     mock_chat_log: MockChatLog,  # noqa: F811
     error: ResponseError | ResponseErrorEvent,
     message: str,
 ) -> None:
@@ -491,6 +502,41 @@ def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
     ]


+def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
+    """Create a web search call item."""
+    return [
+        ResponseOutputItemAddedEvent(
+            item=ResponseFunctionWebSearch(
+                id=id, status="in_progress", type="web_search_call"
+            ),
+            output_index=output_index,
+            type="response.output_item.added",
+        ),
+        ResponseWebSearchCallInProgressEvent(
+            item_id=id,
+            output_index=output_index,
+            type="response.web_search_call.in_progress",
+        ),
+        ResponseWebSearchCallSearchingEvent(
+            item_id=id,
+            output_index=output_index,
+            type="response.web_search_call.searching",
+        ),
+        ResponseWebSearchCallCompletedEvent(
+            item_id=id,
+            output_index=output_index,
+            type="response.web_search_call.completed",
+        ),
+        ResponseOutputItemDoneEvent(
+            item=ResponseFunctionWebSearch(
+                id=id, status="completed", type="web_search_call"
+            ),
+            output_index=output_index,
+            type="response.output_item.done",
+        ),
+    ]
+
+
 async def test_function_call(
     hass: HomeAssistant,
     mock_config_entry_with_assist: MockConfigEntry,
@@ -581,7 +627,6 @@ async def test_function_call_invalid(
     mock_config_entry_with_assist: MockConfigEntry,
     mock_init_component,
     mock_create_stream: AsyncMock,
     mock_chat_log: MockChatLog,  # noqa: F811
     description: str,
     messages: tuple[ResponseStreamEvent],
 ) -> None:
@@ -633,3 +678,60 @@ async def test_assist_api_tools_conversion(

     tools = mock_create_stream.mock_calls[0][2]["tools"]
     assert tools
+
+
+async def test_web_search(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+    mock_create_stream,
+    mock_chat_log: MockChatLog,  # noqa: F811
+) -> None:
+    """Test web_search_tool."""
+    hass.config_entries.async_update_entry(
+        mock_config_entry,
+        options={
+            **mock_config_entry.options,
+            CONF_WEB_SEARCH: True,
+            CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
+            CONF_WEB_SEARCH_USER_LOCATION: True,
+            CONF_WEB_SEARCH_CITY: "San Francisco",
+            CONF_WEB_SEARCH_COUNTRY: "US",
+            CONF_WEB_SEARCH_REGION: "California",
+            CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
+        },
+    )
+    await hass.config_entries.async_reload(mock_config_entry.entry_id)
+
+    message = "Home Assistant now supports ChatGPT Search in Assist"
+    mock_create_stream.return_value = [
+        # Initial conversation
+        (
+            *create_web_search_item(id="ws_A", output_index=0),
+            *create_message_item(id="msg_A", text=message, output_index=1),
+        )
+    ]
+
+    result = await conversation.async_converse(
+        hass,
+        "What's on the latest news?",
+        mock_chat_log.conversation_id,
+        Context(),
+        agent_id="conversation.openai",
+    )
+
+    assert mock_create_stream.mock_calls[0][2]["tools"] == [
+        {
+            "type": "web_search_preview",
+            "search_context_size": "low",
+            "user_location": {
+                "type": "approximate",
+                "city": "San Francisco",
+                "region": "California",
+                "country": "US",
+                "timezone": "America/Los_Angeles",
+            },
+        }
+    ]
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
+    assert result.response.speech["plain"]["speech"] == message, result.response.speech