Mirror of https://github.com/home-assistant/core.git
Synced 2025-05-23 15:27:09 +00:00

Commit:
* Add Web search to OpenAI Conversation integration
* Limit search for gpt-4o models
* Add more tests
738 lines · 22 KiB · Python

"""Tests for the OpenAI integration."""
|
|
|
|
from collections.abc import Generator
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import httpx
|
|
from openai import AuthenticationError, RateLimitError
|
|
from openai.types import ResponseFormatText
|
|
from openai.types.responses import (
|
|
Response,
|
|
ResponseCompletedEvent,
|
|
ResponseContentPartAddedEvent,
|
|
ResponseContentPartDoneEvent,
|
|
ResponseCreatedEvent,
|
|
ResponseError,
|
|
ResponseErrorEvent,
|
|
ResponseFailedEvent,
|
|
ResponseFunctionCallArgumentsDeltaEvent,
|
|
ResponseFunctionCallArgumentsDoneEvent,
|
|
ResponseFunctionToolCall,
|
|
ResponseFunctionWebSearch,
|
|
ResponseIncompleteEvent,
|
|
ResponseInProgressEvent,
|
|
ResponseOutputItemAddedEvent,
|
|
ResponseOutputItemDoneEvent,
|
|
ResponseOutputMessage,
|
|
ResponseOutputText,
|
|
ResponseReasoningItem,
|
|
ResponseStreamEvent,
|
|
ResponseTextConfig,
|
|
ResponseTextDeltaEvent,
|
|
ResponseTextDoneEvent,
|
|
ResponseWebSearchCallCompletedEvent,
|
|
ResponseWebSearchCallInProgressEvent,
|
|
ResponseWebSearchCallSearchingEvent,
|
|
)
|
|
from openai.types.responses.response import IncompleteDetails
|
|
import pytest
|
|
from syrupy.assertion import SnapshotAssertion
|
|
|
|
from homeassistant.components import conversation
|
|
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
|
from homeassistant.components.openai_conversation.const import (
|
|
CONF_WEB_SEARCH,
|
|
CONF_WEB_SEARCH_CITY,
|
|
CONF_WEB_SEARCH_CONTEXT_SIZE,
|
|
CONF_WEB_SEARCH_COUNTRY,
|
|
CONF_WEB_SEARCH_REGION,
|
|
CONF_WEB_SEARCH_TIMEZONE,
|
|
CONF_WEB_SEARCH_USER_LOCATION,
|
|
)
|
|
from homeassistant.const import CONF_LLM_HASS_API
|
|
from homeassistant.core import Context, HomeAssistant
|
|
from homeassistant.helpers import intent
|
|
from homeassistant.setup import async_setup_component
|
|
|
|
from tests.common import MockConfigEntry
|
|
from tests.components.conversation import (
|
|
MockChatLog,
|
|
mock_chat_log, # noqa: F401
|
|
)
|
|
|
|
|
|


@pytest.fixture
def mock_create_stream() -> Generator[AsyncMock]:
    """Mock stream response."""
    async def mock_generator(events, **kwargs):
        response = Response(
            id="resp_A",
            created_at=1700000000,
            error=None,
            incomplete_details=None,
            instructions=kwargs.get("instructions"),
            metadata=kwargs.get("metadata", {}),
            model=kwargs.get("model", "gpt-4o-mini"),
            object="response",
            output=[],
            parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
            temperature=kwargs.get("temperature", 1.0),
            tool_choice=kwargs.get("tool_choice", "auto"),
            tools=kwargs.get("tools"),
            top_p=kwargs.get("top_p", 1.0),
            max_output_tokens=kwargs.get("max_output_tokens", 100000),
            previous_response_id=kwargs.get("previous_response_id"),
            reasoning=kwargs.get("reasoning"),
            status="in_progress",
            text=kwargs.get(
                "text", ResponseTextConfig(format=ResponseFormatText(type="text"))
            ),
            truncation=kwargs.get("truncation", "disabled"),
            usage=None,
            user=kwargs.get("user"),
            store=kwargs.get("store", True),
        )
        yield ResponseCreatedEvent(
            response=response,
            type="response.created",
        )
        yield ResponseInProgressEvent(
            response=response,
            type="response.in_progress",
        )
        response.status = "completed"

        for value in events:
            if isinstance(value, ResponseOutputItemDoneEvent):
                response.output.append(value.item)
            elif isinstance(value, IncompleteDetails):
                response.status = "incomplete"
                response.incomplete_details = value
                break
            if isinstance(value, ResponseError):
                response.status = "failed"
                response.error = value
                break

            yield value

            if isinstance(value, ResponseErrorEvent):
                return

        if response.status == "incomplete":
            yield ResponseIncompleteEvent(
                response=response,
                type="response.incomplete",
            )
        elif response.status == "failed":
            yield ResponseFailedEvent(
                response=response,
                type="response.failed",
            )
        else:
            yield ResponseCompletedEvent(
                response=response,
                type="response.completed",
            )

    with patch(
        "openai.resources.responses.AsyncResponses.create",
        AsyncMock(),
    ) as mock_create:
        mock_create.side_effect = lambda **kwargs: mock_generator(
            mock_create.return_value.pop(0), **kwargs
        )

        yield mock_create
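
# A minimal usage sketch (assumed flow, mirroring the tests below; "Hello!" is
# a hypothetical reply): queue one API call that streams a single assistant
# message, then converse:
#
#     mock_create_stream.return_value = [
#         create_message_item(id="msg_A", text="Hello!", output_index=0)
#     ]
#     result = await conversation.async_converse(
#         hass, "hi", None, Context(), agent_id="conversation.openai"
#     )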


async def test_entity(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test entity properties."""
    state = hass.states.get("conversation.openai")
    assert state
    assert state.attributes["supported_features"] == 0

    hass.config_entries.async_update_entry(
        mock_config_entry,
        options={
            **mock_config_entry.options,
            CONF_LLM_HASS_API: "assist",
        },
    )
    await hass.config_entries.async_reload(mock_config_entry.entry_id)

    state = hass.states.get("conversation.openai")
    assert state
    assert (
        state.attributes["supported_features"]
        == conversation.ConversationEntityFeature.CONTROL
    )


@pytest.mark.parametrize(
    ("exception", "message"),
    [
        (
            RateLimitError(
                response=httpx.Response(status_code=429, request=""),
                body=None,
                message=None,
            ),
            "Rate limited or insufficient funds",
        ),
        (
            AuthenticationError(
                response=httpx.Response(status_code=401, request=""),
                body=None,
                message=None,
            ),
            "Error talking to OpenAI",
        ),
    ],
)
async def test_error_handling(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    exception,
    message,
) -> None:
    """Test that we handle errors when calling completion API."""
    with patch(
        "openai.resources.responses.AsyncResponses.create",
        new_callable=AsyncMock,
        side_effect=exception,
    ):
        result = await conversation.async_converse(
            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
        )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert result.response.speech["plain"]["speech"] == message, result.response.speech


@pytest.mark.parametrize(
    ("reason", "message"),
    [
        (
            "max_output_tokens",
            "max output tokens reached",
        ),
        (
            "content_filter",
            "content filter triggered",
        ),
        (
            None,
            "unknown reason",
        ),
    ],
)
async def test_incomplete_response(
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
    reason: str,
    message: str,
) -> None:
    """Test handling early model stop."""
    # Incomplete details received after some content is generated
    mock_create_stream.return_value = [
        (
            # Start message
            *create_message_item(
                id="msg_A",
                text=["Once upon", " a time, ", "there was "],
                output_index=0,
            ),
            # Length limit or content filter
            IncompleteDetails(reason=reason),
        )
    ]

    result = await conversation.async_converse(
        hass,
        "Please tell me a big story",
        "mock-conversation-id",
        Context(),
        agent_id="conversation.openai",
    )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert (
        result.response.speech["plain"]["speech"]
        == f"OpenAI response incomplete: {message}"
    ), result.response.speech

    # Incomplete details received before any content is generated
    mock_create_stream.return_value = [
        (
            # Start generating response
            *create_reasoning_item(id="rs_A", output_index=0),
            # Length limit or content filter
            IncompleteDetails(reason=reason),
        )
    ]

    result = await conversation.async_converse(
        hass,
        "please tell me a big story",
        "mock-conversation-id",
        Context(),
        agent_id="conversation.openai",
    )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert (
        result.response.speech["plain"]["speech"]
        == f"OpenAI response incomplete: {message}"
    ), result.response.speech


@pytest.mark.parametrize(
    ("error", "message"),
    [
        (
            ResponseError(code="rate_limit_exceeded", message="Rate limit exceeded"),
            "OpenAI response failed: Rate limit exceeded",
        ),
        (
            ResponseErrorEvent(type="error", message="Some error"),
            "OpenAI response error: Some error",
        ),
    ],
)
async def test_failed_response(
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
    error: ResponseError | ResponseErrorEvent,
    message: str,
) -> None:
    """Test handling failed and error responses."""
    mock_create_stream.return_value = [(error,)]

    result = await conversation.async_converse(
        hass,
        "next natural number please",
        "mock-conversation-id",
        Context(),
        agent_id="conversation.openai",
    )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert result.response.speech["plain"]["speech"] == message, result.response.speech


async def test_conversation_agent(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test OpenAIAgent."""
    agent = conversation.get_agent_manager(hass).async_get_agent(
        mock_config_entry.entry_id
    )
    assert agent.supported_languages == "*"
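

# The factory helpers below build the Responses API streaming event sequences
# that the mock_create_stream fixture replays: one list of events per output
# item (assistant message, function tool call, reasoning item, or web search
# call), each bracketed by output_item.added / output_item.done events.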
def create_message_item(
    id: str, text: str | list[str], output_index: int
) -> list[ResponseStreamEvent]:
    """Create a message item."""
    if isinstance(text, str):
        text = [text]

    content = ResponseOutputText(annotations=[], text="", type="output_text")
    events = [
        ResponseOutputItemAddedEvent(
            item=ResponseOutputMessage(
                id=id,
                content=[],
                type="message",
                role="assistant",
                status="in_progress",
            ),
            output_index=output_index,
            type="response.output_item.added",
        ),
        ResponseContentPartAddedEvent(
            content_index=0,
            item_id=id,
            output_index=output_index,
            part=content,
            type="response.content_part.added",
        ),
    ]

    content.text = "".join(text)
    events.extend(
        ResponseTextDeltaEvent(
            content_index=0,
            delta=delta,
            item_id=id,
            output_index=output_index,
            type="response.output_text.delta",
        )
        for delta in text
    )

    events.extend(
        [
            ResponseTextDoneEvent(
                content_index=0,
                item_id=id,
                output_index=output_index,
                text="".join(text),
                type="response.output_text.done",
            ),
            ResponseContentPartDoneEvent(
                content_index=0,
                item_id=id,
                output_index=output_index,
                part=content,
                type="response.content_part.done",
            ),
            ResponseOutputItemDoneEvent(
                item=ResponseOutputMessage(
                    id=id,
                    content=[content],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
                output_index=output_index,
                type="response.output_item.done",
            ),
        ]
    )

    return events
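

# Function tool calls stream their JSON arguments as incremental fragments
# (e.g. '{"para' then 'm1":"call1"}' in test_function_call below), so the
# helper emits one response.function_call_arguments.delta event per fragment
# before the corresponding done events.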
def create_function_tool_call_item(
    id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
) -> list[ResponseStreamEvent]:
    """Create a function tool call item."""
    if isinstance(arguments, str):
        arguments = [arguments]

    events = [
        ResponseOutputItemAddedEvent(
            item=ResponseFunctionToolCall(
                id=id,
                arguments="",
                call_id=call_id,
                name=name,
                type="function_call",
                status="in_progress",
            ),
            output_index=output_index,
            type="response.output_item.added",
        )
    ]

    events.extend(
        ResponseFunctionCallArgumentsDeltaEvent(
            delta=delta,
            item_id=id,
            output_index=output_index,
            type="response.function_call_arguments.delta",
        )
        for delta in arguments
    )

    events.append(
        ResponseFunctionCallArgumentsDoneEvent(
            arguments="".join(arguments),
            item_id=id,
            output_index=output_index,
            type="response.function_call_arguments.done",
        )
    )

    events.append(
        ResponseOutputItemDoneEvent(
            item=ResponseFunctionToolCall(
                id=id,
                arguments="".join(arguments),
                call_id=call_id,
                name=name,
                type="function_call",
                status="completed",
            ),
            output_index=output_index,
            type="response.output_item.done",
        )
    )

    return events
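

# Reasoning items carry no visible output in these tests; the helper below
# just opens and closes an empty reasoning item to simulate the model
# "thinking" before it produces other output items.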
def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
    """Create a reasoning item."""
    return [
        ResponseOutputItemAddedEvent(
            item=ResponseReasoningItem(
                id=id,
                summary=[],
                type="reasoning",
                status=None,
            ),
            output_index=output_index,
            type="response.output_item.added",
        ),
        ResponseOutputItemDoneEvent(
            item=ResponseReasoningItem(
                id=id,
                summary=[],
                type="reasoning",
                status=None,
            ),
            output_index=output_index,
            type="response.output_item.done",
        ),
    ]
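

# A web search call progresses through in_progress -> searching -> completed
# before its output item is closed; the helper below replays that lifecycle.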
def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
    """Create a web search call item."""
    return [
        ResponseOutputItemAddedEvent(
            item=ResponseFunctionWebSearch(
                id=id, status="in_progress", type="web_search_call"
            ),
            output_index=output_index,
            type="response.output_item.added",
        ),
        ResponseWebSearchCallInProgressEvent(
            item_id=id,
            output_index=output_index,
            type="response.web_search_call.in_progress",
        ),
        ResponseWebSearchCallSearchingEvent(
            item_id=id,
            output_index=output_index,
            type="response.web_search_call.searching",
        ),
        ResponseWebSearchCallCompletedEvent(
            item_id=id,
            output_index=output_index,
            type="response.web_search_call.completed",
        ),
        ResponseOutputItemDoneEvent(
            item=ResponseFunctionWebSearch(
                id=id, status="completed", type="web_search_call"
            ),
            output_index=output_index,
            type="response.output_item.done",
        ),
    ]


async def test_function_call(
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
    mock_chat_log: MockChatLog,  # noqa: F811
    snapshot: SnapshotAssertion,
) -> None:
    """Test function call from the assistant."""
    mock_create_stream.return_value = [
        # Initial conversation
        (
            # Wait for the model to think
            *create_reasoning_item(id="rs_A", output_index=0),
            # First tool call
            *create_function_tool_call_item(
                id="fc_1",
                arguments=['{"para', 'm1":"call1"}'],
                call_id="call_call_1",
                name="test_tool",
                output_index=1,
            ),
            # Second tool call
            *create_function_tool_call_item(
                id="fc_2",
                arguments='{"param1":"call2"}',
                call_id="call_call_2",
                name="test_tool",
                output_index=2,
            ),
        ),
        # Response after tool responses
        create_message_item(id="msg_A", text="Cool", output_index=0),
    ]
    mock_chat_log.mock_tool_results(
        {
            "call_call_1": "value1",
            "call_call_2": "value2",
        }
    )

    result = await conversation.async_converse(
        hass,
        "Please call the test function",
        mock_chat_log.conversation_id,
        Context(),
        agent_id="conversation.openai",
    )

    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    # Don't test the prompt, as it's not deterministic
    assert mock_chat_log.content[1:] == snapshot


@pytest.mark.parametrize(
    ("description", "messages"),
    [
        (
            "Test function call started with missing arguments",
            (
                *create_function_tool_call_item(
                    id="fc_1",
                    arguments=[],
                    call_id="call_call_1",
                    name="test_tool",
                    output_index=0,
                ),
                *create_message_item(id="msg_A", text="Cool", output_index=1),
            ),
        ),
        (
            "Test invalid JSON",
            (
                *create_function_tool_call_item(
                    id="fc_1",
                    arguments=['{"para'],
                    call_id="call_call_1",
                    name="test_tool",
                    output_index=0,
                ),
                *create_message_item(id="msg_A", text="Cool", output_index=1),
            ),
        ),
    ],
)
async def test_function_call_invalid(
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
    description: str,
    messages: tuple[ResponseStreamEvent],
) -> None:
    """Test function call containing invalid data."""
    mock_create_stream.return_value = [messages]

    with pytest.raises(ValueError):
        await conversation.async_converse(
            hass,
            "Please call the test function",
            "mock-conversation-id",
            Context(),
            agent_id="conversation.openai",
        )


async def test_assist_api_tools_conversion(
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
    mock_create_stream,
) -> None:
    """Test that we are able to convert actual tools from Assist API."""
    for component in (
        "calendar",
        "climate",
        "cover",
        "humidifier",
        "intent",
        "light",
        "media_player",
        "script",
        "shopping_list",
        "todo",
        "vacuum",
        "weather",
    ):
        assert await async_setup_component(hass, component, {})
        hass.states.async_set(f"{component}.test", "on")
        async_expose_entity(hass, "conversation", f"{component}.test", True)

    mock_create_stream.return_value = [
        create_message_item(id="msg_A", text="Cool", output_index=0)
    ]

    await conversation.async_converse(
        hass, "hello", None, Context(), agent_id="conversation.openai"
    )

    tools = mock_create_stream.mock_calls[0][2]["tools"]
    assert tools


async def test_web_search(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    mock_create_stream,
    mock_chat_log: MockChatLog,  # noqa: F811
) -> None:
    """Test web_search_tool."""
    hass.config_entries.async_update_entry(
        mock_config_entry,
        options={
            **mock_config_entry.options,
            CONF_WEB_SEARCH: True,
            CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
            CONF_WEB_SEARCH_USER_LOCATION: True,
            CONF_WEB_SEARCH_CITY: "San Francisco",
            CONF_WEB_SEARCH_COUNTRY: "US",
            CONF_WEB_SEARCH_REGION: "California",
            CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
        },
    )
    await hass.config_entries.async_reload(mock_config_entry.entry_id)

    message = "Home Assistant now supports ChatGPT Search in Assist"
    mock_create_stream.return_value = [
        # Initial conversation
        (
            *create_web_search_item(id="ws_A", output_index=0),
            *create_message_item(id="msg_A", text=message, output_index=1),
        )
    ]

    result = await conversation.async_converse(
        hass,
        "What's on the latest news?",
        mock_chat_log.conversation_id,
        Context(),
        agent_id="conversation.openai",
    )

    assert mock_create_stream.mock_calls[0][2]["tools"] == [
        {
            "type": "web_search_preview",
            "search_context_size": "low",
            "user_location": {
                "type": "approximate",
                "city": "San Francisco",
                "region": "California",
                "country": "US",
                "timezone": "America/Los_Angeles",
            },
        }
    ]
    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
    assert result.response.speech["plain"]["speech"] == message, result.response.speech