# core/tests/components/openai_conversation/test_conversation.py
"""Tests for the OpenAI integration."""
from collections.abc import Generator
from unittest.mock import AsyncMock, patch
import httpx
from openai import AuthenticationError, RateLimitError
from openai.types import ResponseFormatText
from openai.types.responses import (
Response,
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseError,
ResponseErrorEvent,
ResponseFailedEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
ResponseIncompleteEvent,
ResponseInProgressEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
ResponseStreamEvent,
ResponseTextConfig,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
)
from openai.types.responses.response import IncompleteDetails
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.components.openai_conversation.const import (
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_CITY,
CONF_WEB_SEARCH_CONTEXT_SIZE,
CONF_WEB_SEARCH_COUNTRY,
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
from tests.components.conversation import (
MockChatLog,
mock_chat_log, # noqa: F401
)
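# The fixture below scripts openai.resources.responses.AsyncResponses.create.
# Each scripted API call is one entry (a tuple of stream events) in
# mock_create_stream.return_value; mock_generator wraps the events in the
# response.created / response.in_progress envelope and closes the stream with
# a completed, incomplete, or failed event depending on what was yielded.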
@pytest.fixture
def mock_create_stream() -> Generator[AsyncMock]:
    """Mock stream response."""

    async def mock_generator(events, **kwargs):
        response = Response(
            id="resp_A",
            created_at=1700000000,
            error=None,
            incomplete_details=None,
            instructions=kwargs.get("instructions"),
            metadata=kwargs.get("metadata", {}),
            model=kwargs.get("model", "gpt-4o-mini"),
            object="response",
            output=[],
            parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
            temperature=kwargs.get("temperature", 1.0),
            tool_choice=kwargs.get("tool_choice", "auto"),
            tools=kwargs.get("tools"),
            top_p=kwargs.get("top_p", 1.0),
            max_output_tokens=kwargs.get("max_output_tokens", 100000),
            previous_response_id=kwargs.get("previous_response_id"),
            reasoning=kwargs.get("reasoning"),
            status="in_progress",
            text=kwargs.get(
                "text", ResponseTextConfig(format=ResponseFormatText(type="text"))
            ),
            truncation=kwargs.get("truncation", "disabled"),
            usage=None,
            user=kwargs.get("user"),
            store=kwargs.get("store", True),
        )
        yield ResponseCreatedEvent(
            response=response,
            type="response.created",
        )
        yield ResponseInProgressEvent(
            response=response,
            type="response.in_progress",
        )
        response.status = "completed"

        for value in events:
            if isinstance(value, ResponseOutputItemDoneEvent):
                response.output.append(value.item)
            elif isinstance(value, IncompleteDetails):
                response.status = "incomplete"
                response.incomplete_details = value
                break
            if isinstance(value, ResponseError):
                response.status = "failed"
                response.error = value
                break

            yield value

            if isinstance(value, ResponseErrorEvent):
                return

        if response.status == "incomplete":
            yield ResponseIncompleteEvent(
                response=response,
                type="response.incomplete",
            )
        elif response.status == "failed":
            yield ResponseFailedEvent(
                response=response,
                type="response.failed",
            )
        else:
            yield ResponseCompletedEvent(
                response=response,
                type="response.completed",
            )

    with patch(
        "openai.resources.responses.AsyncResponses.create",
        AsyncMock(),
    ) as mock_create:
        mock_create.side_effect = lambda **kwargs: mock_generator(
            mock_create.return_value.pop(0), **kwargs
        )

        yield mock_create


async def test_entity(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test entity properties."""
    state = hass.states.get("conversation.openai")
    assert state
    assert state.attributes["supported_features"] == 0

    hass.config_entries.async_update_entry(
        mock_config_entry,
        options={
            **mock_config_entry.options,
            CONF_LLM_HASS_API: "assist",
        },
    )
    await hass.config_entries.async_reload(mock_config_entry.entry_id)

    state = hass.states.get("conversation.openai")
    assert state
    assert (
        state.attributes["supported_features"]
        == conversation.ConversationEntityFeature.CONTROL
    )


@pytest.mark.parametrize(
    ("exception", "message"),
    [
        (
            RateLimitError(
                response=httpx.Response(status_code=429, request=""),
                body=None,
                message=None,
            ),
            "Rate limited or insufficient funds",
        ),
        (
            AuthenticationError(
                response=httpx.Response(status_code=401, request=""),
                body=None,
                message=None,
            ),
            "Error talking to OpenAI",
        ),
    ],
)
async def test_error_handling(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    exception,
    message,
) -> None:
    """Test that we handle errors when calling completion API."""
    with patch(
        "openai.resources.responses.AsyncResponses.create",
        new_callable=AsyncMock,
        side_effect=exception,
    ):
        result = await conversation.async_converse(
            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
        )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert result.response.speech["plain"]["speech"] == message, result.response.speech


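# An IncompleteDetails or ResponseError value placed in the scripted events is
# picked up by the mock_create_stream fixture above and turned into a
# response.incomplete / response.failed terminal event for the tests below.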
@pytest.mark.parametrize(
    ("reason", "message"),
    [
        (
            "max_output_tokens",
            "max output tokens reached",
        ),
        (
            "content_filter",
            "content filter triggered",
        ),
        (
            None,
            "unknown reason",
        ),
    ],
)
async def test_incomplete_response(
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
    reason: str,
    message: str,
) -> None:
    """Test handling early model stop."""
    # Incomplete details received after some content is generated
    mock_create_stream.return_value = [
        (
            # Start message
            *create_message_item(
                id="msg_A",
                text=["Once upon", " a time, ", "there was "],
                output_index=0,
            ),
            # Length limit or content filter
            IncompleteDetails(reason=reason),
        )
    ]

    result = await conversation.async_converse(
        hass,
        "Please tell me a big story",
        "mock-conversation-id",
        Context(),
        agent_id="conversation.openai",
    )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert (
        result.response.speech["plain"]["speech"]
        == f"OpenAI response incomplete: {message}"
    ), result.response.speech

    # Incomplete details received before any content is generated
    mock_create_stream.return_value = [
        (
            # Start generating response
            *create_reasoning_item(id="rs_A", output_index=0),
            # Length limit or content filter
            IncompleteDetails(reason=reason),
        )
    ]

    result = await conversation.async_converse(
        hass,
        "please tell me a big story",
        "mock-conversation-id",
        Context(),
        agent_id="conversation.openai",
    )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert (
        result.response.speech["plain"]["speech"]
        == f"OpenAI response incomplete: {message}"
    ), result.response.speech


@pytest.mark.parametrize(
    ("error", "message"),
    [
        (
            ResponseError(code="rate_limit_exceeded", message="Rate limit exceeded"),
            "OpenAI response failed: Rate limit exceeded",
        ),
        (
            ResponseErrorEvent(type="error", message="Some error"),
            "OpenAI response error: Some error",
        ),
    ],
)
async def test_failed_response(
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
    error: ResponseError | ResponseErrorEvent,
    message: str,
) -> None:
    """Test handling failed and error responses."""
    mock_create_stream.return_value = [(error,)]

    result = await conversation.async_converse(
        hass,
        "next natural number please",
        "mock-conversation-id",
        Context(),
        agent_id="conversation.openai",
    )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert result.response.speech["plain"]["speech"] == message, result.response.speech


async def test_conversation_agent(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test OpenAIAgent."""
    agent = conversation.get_agent_manager(hass).async_get_agent(
        mock_config_entry.entry_id
    )
    assert agent.supported_languages == "*"


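# The helpers below build the ordered event sequence the Responses API streams
# for a single output item: response.output_item.added, any intermediate
# delta/progress events, then the matching *.done events. Tests splat them
# together into one tuple per scripted assistant turn.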
def create_message_item(
    id: str, text: str | list[str], output_index: int
) -> list[ResponseStreamEvent]:
    """Create a message item."""
    if isinstance(text, str):
        text = [text]

    content = ResponseOutputText(annotations=[], text="", type="output_text")
    events = [
        ResponseOutputItemAddedEvent(
            item=ResponseOutputMessage(
                id=id,
                content=[],
                type="message",
                role="assistant",
                status="in_progress",
            ),
            output_index=output_index,
            type="response.output_item.added",
        ),
        ResponseContentPartAddedEvent(
            content_index=0,
            item_id=id,
            output_index=output_index,
            part=content,
            type="response.content_part.added",
        ),
    ]

    content.text = "".join(text)
    events.extend(
        ResponseTextDeltaEvent(
            content_index=0,
            delta=delta,
            item_id=id,
            output_index=output_index,
            type="response.output_text.delta",
        )
        for delta in text
    )

    events.extend(
        [
            ResponseTextDoneEvent(
                content_index=0,
                item_id=id,
                output_index=output_index,
                text="".join(text),
                type="response.output_text.done",
            ),
            ResponseContentPartDoneEvent(
                content_index=0,
                item_id=id,
                output_index=output_index,
                part=content,
                type="response.content_part.done",
            ),
            ResponseOutputItemDoneEvent(
                item=ResponseOutputMessage(
                    id=id,
                    content=[content],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
                output_index=output_index,
                type="response.output_item.done",
            ),
        ]
    )

    return events


def create_function_tool_call_item(
    id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
) -> list[ResponseStreamEvent]:
    """Create a function tool call item."""
    if isinstance(arguments, str):
        arguments = [arguments]

    events = [
        ResponseOutputItemAddedEvent(
            item=ResponseFunctionToolCall(
                id=id,
                arguments="",
                call_id=call_id,
                name=name,
                type="function_call",
                status="in_progress",
            ),
            output_index=output_index,
            type="response.output_item.added",
        )
    ]

    events.extend(
        ResponseFunctionCallArgumentsDeltaEvent(
            delta=delta,
            item_id=id,
            output_index=output_index,
            type="response.function_call_arguments.delta",
        )
        for delta in arguments
    )

    events.append(
        ResponseFunctionCallArgumentsDoneEvent(
            arguments="".join(arguments),
            item_id=id,
            output_index=output_index,
            type="response.function_call_arguments.done",
        )
    )

    events.append(
        ResponseOutputItemDoneEvent(
            item=ResponseFunctionToolCall(
                id=id,
                arguments="".join(arguments),
                call_id=call_id,
                name=name,
                type="function_call",
                status="completed",
            ),
            output_index=output_index,
            type="response.output_item.done",
        )
    )

    return events


def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
    """Create a reasoning item."""
    return [
        ResponseOutputItemAddedEvent(
            item=ResponseReasoningItem(
                id=id,
                summary=[],
                type="reasoning",
                status=None,
            ),
            output_index=output_index,
            type="response.output_item.added",
        ),
        ResponseOutputItemDoneEvent(
            item=ResponseReasoningItem(
                id=id,
                summary=[],
                type="reasoning",
                status=None,
            ),
            output_index=output_index,
            type="response.output_item.done",
        ),
    ]


def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
    """Create a web search call item."""
    return [
        ResponseOutputItemAddedEvent(
            item=ResponseFunctionWebSearch(
                id=id, status="in_progress", type="web_search_call"
            ),
            output_index=output_index,
            type="response.output_item.added",
        ),
        ResponseWebSearchCallInProgressEvent(
            item_id=id,
            output_index=output_index,
            type="response.web_search_call.in_progress",
        ),
        ResponseWebSearchCallSearchingEvent(
            item_id=id,
            output_index=output_index,
            type="response.web_search_call.searching",
        ),
        ResponseWebSearchCallCompletedEvent(
            item_id=id,
            output_index=output_index,
            type="response.web_search_call.completed",
        ),
        ResponseOutputItemDoneEvent(
            item=ResponseFunctionWebSearch(
                id=id, status="completed", type="web_search_call"
            ),
            output_index=output_index,
            type="response.output_item.done",
        ),
    ]


async def test_function_call(
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
    mock_chat_log: MockChatLog,  # noqa: F811
    snapshot: SnapshotAssertion,
) -> None:
    """Test function call from the assistant."""
    mock_create_stream.return_value = [
        # Initial conversation
        (
            # Wait for the model to think
            *create_reasoning_item(id="rs_A", output_index=0),
            # First tool call
            *create_function_tool_call_item(
                id="fc_1",
                arguments=['{"para', 'm1":"call1"}'],
                call_id="call_call_1",
                name="test_tool",
                output_index=1,
            ),
            # Second tool call
            *create_function_tool_call_item(
                id="fc_2",
                arguments='{"param1":"call2"}',
                call_id="call_call_2",
                name="test_tool",
                output_index=2,
            ),
        ),
        # Response after tool responses
        create_message_item(id="msg_A", text="Cool", output_index=0),
    ]
    mock_chat_log.mock_tool_results(
        {
            "call_call_1": "value1",
            "call_call_2": "value2",
        }
    )

    result = await conversation.async_converse(
        hass,
        "Please call the test function",
        mock_chat_log.conversation_id,
        Context(),
        agent_id="conversation.openai",
    )

    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    # Don't test the prompt, as it's not deterministic
    assert mock_chat_log.content[1:] == snapshot


@pytest.mark.parametrize(
    ("description", "messages"),
    [
        (
            "Test function call started with missing arguments",
            (
                *create_function_tool_call_item(
                    id="fc_1",
                    arguments=[],
                    call_id="call_call_1",
                    name="test_tool",
                    output_index=0,
                ),
                *create_message_item(id="msg_A", text="Cool", output_index=1),
            ),
        ),
        (
            "Test invalid JSON",
            (
                *create_function_tool_call_item(
                    id="fc_1",
                    arguments=['{"para'],
                    call_id="call_call_1",
                    name="test_tool",
                    output_index=0,
                ),
                *create_message_item(id="msg_A", text="Cool", output_index=1),
            ),
        ),
    ],
)
async def test_function_call_invalid(
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
    description: str,
    messages: tuple[ResponseStreamEvent],
) -> None:
    """Test function call containing invalid data."""
    mock_create_stream.return_value = [messages]

    with pytest.raises(ValueError):
        await conversation.async_converse(
            hass,
            "Please call the test function",
            "mock-conversation-id",
            Context(),
            agent_id="conversation.openai",
        )


async def test_assist_api_tools_conversion(
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
    mock_create_stream,
) -> None:
    """Test that we are able to convert actual tools from Assist API."""
    for component in (
        "calendar",
        "climate",
        "cover",
        "humidifier",
        "intent",
        "light",
        "media_player",
        "script",
        "shopping_list",
        "todo",
        "vacuum",
        "weather",
    ):
        assert await async_setup_component(hass, component, {})
        hass.states.async_set(f"{component}.test", "on")
        async_expose_entity(hass, "conversation", f"{component}.test", True)

    mock_create_stream.return_value = [
        create_message_item(id="msg_A", text="Cool", output_index=0)
    ]

    await conversation.async_converse(
        hass, "hello", None, Context(), agent_id="conversation.openai"
    )

    tools = mock_create_stream.mock_calls[0][2]["tools"]
    assert tools


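# Web search options stored on the config entry are forwarded to the API as a
# web_search_preview tool; the test below checks the full tool payload,
# including the approximate user location.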
async def test_web_search(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    mock_create_stream,
    mock_chat_log: MockChatLog,  # noqa: F811
) -> None:
    """Test web_search_tool."""
    hass.config_entries.async_update_entry(
        mock_config_entry,
        options={
            **mock_config_entry.options,
            CONF_WEB_SEARCH: True,
            CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
            CONF_WEB_SEARCH_USER_LOCATION: True,
            CONF_WEB_SEARCH_CITY: "San Francisco",
            CONF_WEB_SEARCH_COUNTRY: "US",
            CONF_WEB_SEARCH_REGION: "California",
            CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
        },
    )
    await hass.config_entries.async_reload(mock_config_entry.entry_id)

    message = "Home Assistant now supports ChatGPT Search in Assist"
    mock_create_stream.return_value = [
        # Initial conversation
        (
            *create_web_search_item(id="ws_A", output_index=0),
            *create_message_item(id="msg_A", text=message, output_index=1),
        )
    ]

    result = await conversation.async_converse(
        hass,
        "What's on the latest news?",
        mock_chat_log.conversation_id,
        Context(),
        agent_id="conversation.openai",
    )

    assert mock_create_stream.mock_calls[0][2]["tools"] == [
        {
            "type": "web_search_preview",
            "search_context_size": "low",
            "user_location": {
                "type": "approximate",
                "city": "San Francisco",
                "region": "California",
                "country": "US",
                "timezone": "America/Los_Angeles",
            },
        }
    ]
    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    assert result.response.speech["plain"]["speech"] == message, result.response.speech