Add OpenAI AI Task entity (#148295)

This commit is contained in:
Paulus Schoutsen 2025-07-10 23:08:56 +02:00 committed by GitHub
parent f0a636949a
commit 0e09a47476
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 1152 additions and 463 deletions

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from pathlib import Path from pathlib import Path
from types import MappingProxyType
import openai import openai
from openai.types.images_response import ImagesResponse from openai.types.images_response import ImagesResponse
@ -45,9 +46,11 @@ from .const import (
CONF_REASONING_EFFORT, CONF_REASONING_EFFORT,
CONF_TEMPERATURE, CONF_TEMPERATURE,
CONF_TOP_P, CONF_TOP_P,
DEFAULT_AI_TASK_NAME,
DEFAULT_NAME, DEFAULT_NAME,
DOMAIN, DOMAIN,
LOGGER, LOGGER,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS, RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT,
@ -59,7 +62,7 @@ from .entity import async_prepare_files_for_prompt
SERVICE_GENERATE_IMAGE = "generate_image" SERVICE_GENERATE_IMAGE = "generate_image"
SERVICE_GENERATE_CONTENT = "generate_content" SERVICE_GENERATE_CONTENT = "generate_content"
PLATFORMS = (Platform.CONVERSATION,) PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient] type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
@ -153,7 +156,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
EasyInputMessageParam(type="message", role="user", content=content) EasyInputMessageParam(type="message", role="user", content=content)
] ]
try:
model_args = { model_args = {
"model": model, "model": model,
"input": messages, "input": messages,
@ -175,6 +177,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
) )
} }
try:
response: Response = await client.responses.create(**model_args) response: Response = await client.responses.create(**model_args)
except openai.OpenAIError as err: except openai.OpenAIError as err:
@ -361,6 +364,18 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) ->
hass.config_entries.async_update_entry(entry, minor_version=2) hass.config_entries.async_update_entry(entry, minor_version=2)
if entry.version == 2 and entry.minor_version == 2:
hass.config_entries.async_add_subentry(
entry,
ConfigSubentry(
data=MappingProxyType(RECOMMENDED_AI_TASK_OPTIONS),
subentry_type="ai_task_data",
title=DEFAULT_AI_TASK_NAME,
unique_id=None,
),
)
hass.config_entries.async_update_entry(entry, minor_version=3)
LOGGER.debug( LOGGER.debug(
"Migration to version %s:%s successful", entry.version, entry.minor_version "Migration to version %s:%s successful", entry.version, entry.minor_version
) )

View File

@ -0,0 +1,77 @@
"""AI Task integration for OpenAI."""
from __future__ import annotations
from json import JSONDecodeError
import logging
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
from .entity import OpenAIBaseLLMEntity
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up AI Task entities."""
for subentry in config_entry.subentries.values():
if subentry.subentry_type != "ai_task_data":
continue
async_add_entities(
[OpenAITaskEntity(config_entry, subentry)],
config_subentry_id=subentry.subentry_id,
)
class OpenAITaskEntity(
ai_task.AITaskEntity,
OpenAIBaseLLMEntity,
):
"""OpenAI AI Task entity."""
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
async def _async_generate_data(
self,
task: ai_task.GenDataTask,
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Handle a generate data task."""
await self._async_handle_chat_log(chat_log, task.name, task.structure)
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
raise HomeAssistantError(
"Last content in chat log is not an AssistantContent"
)
text = chat_log.content[-1].content or ""
if not task.structure:
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=text,
)
try:
data = json_loads(text)
except JSONDecodeError as err:
_LOGGER.error(
"Failed to parse JSON response: %s. Response: %s",
err,
text,
)
raise HomeAssistantError("Error with OpenAI structured response") from err
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=data,
)

View File

@ -55,9 +55,12 @@ from .const import (
CONF_WEB_SEARCH_REGION, CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE, CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION, CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME, DEFAULT_CONVERSATION_NAME,
DOMAIN, DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_MAX_TOKENS, RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE, RECOMMENDED_TEMPERATURE,
@ -77,12 +80,6 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
} }
) )
RECOMMENDED_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
}
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect. """Validate the user input allows us to connect.
@ -99,7 +96,7 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for OpenAI Conversation.""" """Handle a config flow for OpenAI Conversation."""
VERSION = 2 VERSION = 2
MINOR_VERSION = 2 MINOR_VERSION = 3
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
@ -129,10 +126,16 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
subentries=[ subentries=[
{ {
"subentry_type": "conversation", "subentry_type": "conversation",
"data": RECOMMENDED_OPTIONS, "data": RECOMMENDED_CONVERSATION_OPTIONS,
"title": DEFAULT_CONVERSATION_NAME, "title": DEFAULT_CONVERSATION_NAME,
"unique_id": None, "unique_id": None,
} },
{
"subentry_type": "ai_task_data",
"data": RECOMMENDED_AI_TASK_OPTIONS,
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
], ],
) )
@ -146,11 +149,14 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
cls, config_entry: ConfigEntry cls, config_entry: ConfigEntry
) -> dict[str, type[ConfigSubentryFlow]]: ) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this integration.""" """Return subentries supported by this integration."""
return {"conversation": ConversationSubentryFlowHandler} return {
"conversation": OpenAISubentryFlowHandler,
"ai_task_data": OpenAISubentryFlowHandler,
}
class ConversationSubentryFlowHandler(ConfigSubentryFlow): class OpenAISubentryFlowHandler(ConfigSubentryFlow):
"""Flow for managing conversation subentries.""" """Flow for managing OpenAI subentries."""
last_rendered_recommended = False last_rendered_recommended = False
options: dict[str, Any] options: dict[str, Any]
@ -164,7 +170,10 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult: ) -> SubentryFlowResult:
"""Add a subentry.""" """Add a subentry."""
self.options = RECOMMENDED_OPTIONS.copy() if self._subentry_type == "ai_task_data":
self.options = RECOMMENDED_AI_TASK_OPTIONS.copy()
else:
self.options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
return await self.async_step_init() return await self.async_step_init()
async def async_step_reconfigure( async def async_step_reconfigure(
@ -181,6 +190,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
# abort if entry is not loaded # abort if entry is not loaded
if self._get_entry().state != ConfigEntryState.LOADED: if self._get_entry().state != ConfigEntryState.LOADED:
return self.async_abort(reason="entry_not_loaded") return self.async_abort(reason="entry_not_loaded")
options = self.options options = self.options
hass_apis: list[SelectOptionDict] = [ hass_apis: list[SelectOptionDict] = [
@ -198,10 +208,13 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
step_schema: VolDictType = {} step_schema: VolDictType = {}
if self._is_new: if self._is_new:
step_schema[vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME)] = ( if self._subentry_type == "ai_task_data":
str default_name = DEFAULT_AI_TASK_NAME
) else:
default_name = DEFAULT_CONVERSATION_NAME
step_schema[vol.Required(CONF_NAME, default=default_name)] = str
if self._subentry_type == "conversation":
step_schema.update( step_schema.update(
{ {
vol.Optional( vol.Optional(
@ -215,12 +228,13 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
vol.Optional(CONF_LLM_HASS_API): SelectSelector( vol.Optional(CONF_LLM_HASS_API): SelectSelector(
SelectSelectorConfig(options=hass_apis, multiple=True) SelectSelectorConfig(options=hass_apis, multiple=True)
), ),
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
} }
) )
step_schema[
vol.Required(CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False))
] = bool
if user_input is not None: if user_input is not None:
if not user_input.get(CONF_LLM_HASS_API): if not user_input.get(CONF_LLM_HASS_API):
user_input.pop(CONF_LLM_HASS_API, None) user_input.pop(CONF_LLM_HASS_API, None)
@ -320,7 +334,9 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
elif CONF_REASONING_EFFORT in options: elif CONF_REASONING_EFFORT in options:
options.pop(CONF_REASONING_EFFORT) options.pop(CONF_REASONING_EFFORT)
if not model.startswith(tuple(UNSUPPORTED_WEB_SEARCH_MODELS)): if self._subentry_type == "conversation" and not model.startswith(
tuple(UNSUPPORTED_WEB_SEARCH_MODELS)
):
step_schema.update( step_schema.update(
{ {
vol.Optional( vol.Optional(
@ -362,7 +378,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
if not step_schema: if not step_schema:
if self._is_new: if self._is_new:
return self.async_create_entry( return self.async_create_entry(
title=options.pop(CONF_NAME, DEFAULT_CONVERSATION_NAME), title=options.pop(CONF_NAME),
data=options, data=options,
) )
return self.async_update_and_abort( return self.async_update_and_abort(
@ -384,7 +400,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
options.update(user_input) options.update(user_input)
if self._is_new: if self._is_new:
return self.async_create_entry( return self.async_create_entry(
title=options.pop(CONF_NAME, DEFAULT_CONVERSATION_NAME), title=options.pop(CONF_NAME),
data=options, data=options,
) )
return self.async_update_and_abort( return self.async_update_and_abort(

View File

@ -2,10 +2,14 @@
import logging import logging
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.helpers import llm
DOMAIN = "openai_conversation" DOMAIN = "openai_conversation"
LOGGER: logging.Logger = logging.getLogger(__package__) LOGGER: logging.Logger = logging.getLogger(__package__)
DEFAULT_CONVERSATION_NAME = "OpenAI Conversation" DEFAULT_CONVERSATION_NAME = "OpenAI Conversation"
DEFAULT_AI_TASK_NAME = "OpenAI AI Task"
DEFAULT_NAME = "OpenAI Conversation" DEFAULT_NAME = "OpenAI Conversation"
CONF_CHAT_MODEL = "chat_model" CONF_CHAT_MODEL = "chat_model"
@ -51,3 +55,12 @@ UNSUPPORTED_WEB_SEARCH_MODELS: list[str] = [
"o1", "o1",
"o3-mini", "o3-mini",
] ]
RECOMMENDED_CONVERSATION_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
}
RECOMMENDED_AI_TASK_OPTIONS = {
CONF_RECOMMENDED: True,
}

View File

@ -39,6 +39,7 @@ from openai.types.responses import (
) )
from openai.types.responses.response_input_param import FunctionCallOutput from openai.types.responses.response_input_param import FunctionCallOutput
from openai.types.responses.web_search_tool_param import UserLocation from openai.types.responses.web_search_tool_param import UserLocation
import voluptuous as vol
from voluptuous_openapi import convert from voluptuous_openapi import convert
from homeassistant.components import conversation from homeassistant.components import conversation
@ -47,6 +48,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, llm from homeassistant.helpers import device_registry as dr, llm
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.util import slugify
from .const import ( from .const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
@ -79,6 +81,47 @@ if TYPE_CHECKING:
MAX_TOOL_ITERATIONS = 10 MAX_TOOL_ITERATIONS = 10
def _adjust_schema(schema: dict[str, Any]) -> None:
"""Adjust the schema to be compatible with OpenAI API."""
if schema["type"] == "object":
if "properties" not in schema:
return
if "required" not in schema:
schema["required"] = []
# Ensure all properties are required
for prop, prop_info in schema["properties"].items():
_adjust_schema(prop_info)
if prop not in schema["required"]:
prop_info["type"] = [prop_info["type"], "null"]
schema["required"].append(prop)
elif schema["type"] == "array":
if "items" not in schema:
return
_adjust_schema(schema["items"])
def _format_structured_output(
schema: vol.Schema, llm_api: llm.APIInstance | None
) -> dict[str, Any]:
"""Format the schema to be compatible with OpenAI API."""
result: dict[str, Any] = convert(
schema,
custom_serializer=(
llm_api.custom_serializer if llm_api else llm.selector_serializer
),
)
_adjust_schema(result)
result["strict"] = True
result["additionalProperties"] = False
return result
def _format_tool( def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> FunctionToolParam: ) -> FunctionToolParam:
@ -243,6 +286,8 @@ class OpenAIBaseLLMEntity(Entity):
async def _async_handle_chat_log( async def _async_handle_chat_log(
self, self,
chat_log: conversation.ChatLog, chat_log: conversation.ChatLog,
structure_name: str | None = None,
structure: vol.Schema | None = None,
) -> None: ) -> None:
"""Generate an answer for the chat log.""" """Generate an answer for the chat log."""
options = self.subentry.data options = self.subentry.data
@ -273,23 +318,10 @@ class OpenAIBaseLLMEntity(Entity):
tools = [] tools = []
tools.append(web_search) tools.append(web_search)
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
messages = [
m
for content in chat_log.content
for m in _convert_content_to_param(content)
]
client = self.entry.runtime_data
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
model_args = { model_args = {
"model": model, "model": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
"input": messages, "input": [],
"max_output_tokens": options.get( "max_output_tokens": options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
),
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P), "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
"user": chat_log.conversation_id, "user": chat_log.conversation_id,
@ -299,13 +331,34 @@ class OpenAIBaseLLMEntity(Entity):
if tools: if tools:
model_args["tools"] = tools model_args["tools"] = tools
if model.startswith("o"): if model_args["model"].startswith("o"):
model_args["reasoning"] = { model_args["reasoning"] = {
"effort": options.get( "effort": options.get(
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
) )
} }
model_args["include"] = ["reasoning.encrypted_content"] else:
model_args["store"] = False
messages = [
m
for content in chat_log.content
for m in _convert_content_to_param(content)
]
if structure and structure_name:
model_args["text"] = {
"format": {
"type": "json_schema",
"name": slugify(structure_name),
"schema": _format_structured_output(structure, chat_log.llm_api),
},
}
client = self.entry.runtime_data
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
model_args["input"] = messages
try: try:
result = await client.responses.create(**model_args) result = await client.responses.create(**model_args)

View File

@ -68,6 +68,52 @@
"error": { "error": {
"model_not_supported": "This model is not supported, please select a different model" "model_not_supported": "This model is not supported, please select a different model"
} }
},
"ai_task_data": {
"initiate_flow": {
"user": "Add Generate data with AI service",
"reconfigure": "Reconfigure Generate data with AI service"
},
"entry_type": "Generate data with AI service",
"step": {
"init": {
"data": {
"name": "[%key:common::config_flow::data::name%]",
"recommended": "[%key:component::openai_conversation::config_subentries::conversation::step::init::data::recommended%]"
}
},
"advanced": {
"title": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::title%]",
"data": {
"chat_model": "[%key:common::generic::model%]",
"max_tokens": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::max_tokens%]",
"temperature": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::temperature%]",
"top_p": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::top_p%]"
}
},
"model": {
"title": "[%key:component::openai_conversation::config_subentries::conversation::step::model::title%]",
"data": {
"reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::reasoning_effort%]",
"web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::web_search%]",
"search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::search_context_size%]",
"user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::user_location%]"
},
"data_description": {
"reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::reasoning_effort%]",
"web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::web_search%]",
"search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::search_context_size%]",
"user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::user_location%]"
}
}
},
"abort": {
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
"entry_not_loaded": "[%key:component::openai_conversation::config_subentries::conversation::abort::entry_not_loaded%]"
},
"error": {
"model_not_supported": "[%key:component::openai_conversation::config_subentries::conversation::error::model_not_supported%]"
}
} }
}, },
"selector": { "selector": {

View File

@ -1 +1,241 @@
"""Tests for the OpenAI Conversation integration.""" """Tests for the OpenAI Conversation integration."""
from openai.types.responses import (
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
ResponseStreamEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
)
from openai.types.responses.response_function_web_search import ActionSearch
def create_message_item(
id: str, text: str | list[str], output_index: int
) -> list[ResponseStreamEvent]:
"""Create a message item."""
if isinstance(text, str):
text = [text]
content = ResponseOutputText(annotations=[], text="", type="output_text")
events = [
ResponseOutputItemAddedEvent(
item=ResponseOutputMessage(
id=id,
content=[],
type="message",
role="assistant",
status="in_progress",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseContentPartAddedEvent(
content_index=0,
item_id=id,
output_index=output_index,
part=content,
sequence_number=0,
type="response.content_part.added",
),
]
content.text = "".join(text)
events.extend(
ResponseTextDeltaEvent(
content_index=0,
delta=delta,
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.output_text.delta",
)
for delta in text
)
events.extend(
[
ResponseTextDoneEvent(
content_index=0,
item_id=id,
output_index=output_index,
text="".join(text),
sequence_number=0,
type="response.output_text.done",
),
ResponseContentPartDoneEvent(
content_index=0,
item_id=id,
output_index=output_index,
part=content,
sequence_number=0,
type="response.content_part.done",
),
ResponseOutputItemDoneEvent(
item=ResponseOutputMessage(
id=id,
content=[content],
role="assistant",
status="completed",
type="message",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]
)
return events
def create_function_tool_call_item(
id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
) -> list[ResponseStreamEvent]:
"""Create a function tool call item."""
if isinstance(arguments, str):
arguments = [arguments]
events = [
ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=id,
arguments="",
call_id=call_id,
name=name,
type="function_call",
status="in_progress",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
)
]
events.extend(
ResponseFunctionCallArgumentsDeltaEvent(
delta=delta,
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.function_call_arguments.delta",
)
for delta in arguments
)
events.append(
ResponseFunctionCallArgumentsDoneEvent(
arguments="".join(arguments),
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.function_call_arguments.done",
)
)
events.append(
ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=id,
arguments="".join(arguments),
call_id=call_id,
name=name,
type="function_call",
status="completed",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
)
)
return events
def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
"""Create a reasoning item."""
return [
ResponseOutputItemAddedEvent(
item=ResponseReasoningItem(
id=id,
summary=[],
type="reasoning",
status=None,
encrypted_content="AAA",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseOutputItemDoneEvent(
item=ResponseReasoningItem(
id=id,
summary=[],
type="reasoning",
status=None,
encrypted_content="AAABBB",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]
def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
"""Create a web search call item."""
return [
ResponseOutputItemAddedEvent(
item=ResponseFunctionWebSearch(
id=id,
status="in_progress",
action=ActionSearch(query="query", type="search"),
type="web_search_call",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseWebSearchCallInProgressEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.in_progress",
),
ResponseWebSearchCallSearchingEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.searching",
),
ResponseWebSearchCallCompletedEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.completed",
),
ResponseOutputItemDoneEvent(
item=ResponseFunctionWebSearch(
id=id,
status="completed",
action=ActionSearch(query="query", type="search"),
type="web_search_call",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]

View File

@ -1,13 +1,30 @@
"""Tests helpers.""" """Tests helpers."""
from collections.abc import Generator
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import AsyncMock, patch
from openai.types import ResponseFormatText
from openai.types.responses import (
Response,
ResponseCompletedEvent,
ResponseCreatedEvent,
ResponseError,
ResponseErrorEvent,
ResponseFailedEvent,
ResponseIncompleteEvent,
ResponseInProgressEvent,
ResponseOutputItemDoneEvent,
ResponseTextConfig,
)
from openai.types.responses.response import IncompleteDetails
import pytest import pytest
from homeassistant.components.openai_conversation.const import ( from homeassistant.components.openai_conversation.const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME, DEFAULT_CONVERSATION_NAME,
RECOMMENDED_AI_TASK_OPTIONS,
) )
from homeassistant.config_entries import ConfigSubentryData from homeassistant.config_entries import ConfigSubentryData
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import CONF_LLM_HASS_API
@ -19,14 +36,14 @@ from tests.common import MockConfigEntry
@pytest.fixture @pytest.fixture
def mock_subentry_data() -> dict[str, Any]: def mock_conversation_subentry_data() -> dict[str, Any]:
"""Mock subentry data.""" """Mock subentry data."""
return {} return {}
@pytest.fixture @pytest.fixture
def mock_config_entry( def mock_config_entry(
hass: HomeAssistant, mock_subentry_data: dict[str, Any] hass: HomeAssistant, mock_conversation_subentry_data: dict[str, Any]
) -> MockConfigEntry: ) -> MockConfigEntry:
"""Mock a config entry.""" """Mock a config entry."""
entry = MockConfigEntry( entry = MockConfigEntry(
@ -36,13 +53,20 @@ def mock_config_entry(
"api_key": "bla", "api_key": "bla",
}, },
version=2, version=2,
minor_version=3,
subentries_data=[ subentries_data=[
ConfigSubentryData( ConfigSubentryData(
data=mock_subentry_data, data=mock_conversation_subentry_data,
subentry_type="conversation", subentry_type="conversation",
title=DEFAULT_CONVERSATION_NAME, title=DEFAULT_CONVERSATION_NAME,
unique_id=None, unique_id=None,
) ),
ConfigSubentryData(
data=RECOMMENDED_AI_TASK_OPTIONS,
subentry_type="ai_task_data",
title=DEFAULT_AI_TASK_NAME,
unique_id=None,
),
], ],
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
@ -91,3 +115,94 @@ async def mock_init_component(
async def setup_ha(hass: HomeAssistant) -> None: async def setup_ha(hass: HomeAssistant) -> None:
"""Set up Home Assistant.""" """Set up Home Assistant."""
assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "homeassistant", {})
@pytest.fixture
def mock_create_stream() -> Generator[AsyncMock]:
"""Mock stream response."""
async def mock_generator(events, **kwargs):
response = Response(
id="resp_A",
created_at=1700000000,
error=None,
incomplete_details=None,
instructions=kwargs.get("instructions"),
metadata=kwargs.get("metadata", {}),
model=kwargs.get("model", "gpt-4o-mini"),
object="response",
output=[],
parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
temperature=kwargs.get("temperature", 1.0),
tool_choice=kwargs.get("tool_choice", "auto"),
tools=kwargs.get("tools", []),
top_p=kwargs.get("top_p", 1.0),
max_output_tokens=kwargs.get("max_output_tokens", 100000),
previous_response_id=kwargs.get("previous_response_id"),
reasoning=kwargs.get("reasoning"),
status="in_progress",
text=kwargs.get(
"text", ResponseTextConfig(format=ResponseFormatText(type="text"))
),
truncation=kwargs.get("truncation", "disabled"),
usage=None,
user=kwargs.get("user"),
store=kwargs.get("store", True),
)
yield ResponseCreatedEvent(
response=response,
sequence_number=0,
type="response.created",
)
yield ResponseInProgressEvent(
response=response,
sequence_number=0,
type="response.in_progress",
)
response.status = "completed"
for value in events:
if isinstance(value, ResponseOutputItemDoneEvent):
response.output.append(value.item)
elif isinstance(value, IncompleteDetails):
response.status = "incomplete"
response.incomplete_details = value
break
if isinstance(value, ResponseError):
response.status = "failed"
response.error = value
break
yield value
if isinstance(value, ResponseErrorEvent):
return
if response.status == "incomplete":
yield ResponseIncompleteEvent(
response=response,
sequence_number=0,
type="response.incomplete",
)
elif response.status == "failed":
yield ResponseFailedEvent(
response=response,
sequence_number=0,
type="response.failed",
)
else:
yield ResponseCompletedEvent(
response=response,
sequence_number=0,
type="response.completed",
)
with patch(
"openai.resources.responses.AsyncResponses.create",
AsyncMock(),
) as mock_create:
mock_create.side_effect = lambda **kwargs: mock_generator(
mock_create.return_value.pop(0), **kwargs
)
yield mock_create

View File

@ -1,5 +1,5 @@
# serializer version: 1 # serializer version: 1
# name: test_devices[mock_subentry_data0] # name: test_devices[mock_conversation_subentry_data0]
DeviceRegistryEntrySnapshot({ DeviceRegistryEntrySnapshot({
'area_id': None, 'area_id': None,
'config_entries': <ANY>, 'config_entries': <ANY>,
@ -26,7 +26,7 @@
'via_device_id': None, 'via_device_id': None,
}) })
# --- # ---
# name: test_devices[mock_subentry_data1] # name: test_devices[mock_conversation_subentry_data1]
DeviceRegistryEntrySnapshot({ DeviceRegistryEntrySnapshot({
'area_id': None, 'area_id': None,
'config_entries': <ANY>, 'config_entries': <ANY>,

View File

@ -0,0 +1,124 @@
"""Test AI Task platform of OpenAI Conversation integration."""
from unittest.mock import AsyncMock
import pytest
import voluptuous as vol
from homeassistant.components import ai_task
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er, selector
from . import create_message_item
from tests.common import MockConfigEntry
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_data(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_create_stream: AsyncMock,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task data generation."""
entity_id = "ai_task.openai_ai_task"
# Ensure entity is linked to the subentry
entity_entry = entity_registry.async_get(entity_id)
ai_task_entry = next(
iter(
entry
for entry in mock_config_entry.subentries.values()
if entry.subentry_type == "ai_task_data"
)
)
assert entity_entry is not None
assert entity_entry.config_entry_id == mock_config_entry.entry_id
assert entity_entry.config_subentry_id == ai_task_entry.subentry_id
# Mock the OpenAI response stream
mock_create_stream.return_value = [
create_message_item(id="msg_A", text="The test data", output_index=0)
]
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Generate test data",
)
assert result.data == "The test data"
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_structured_data(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_create_stream: AsyncMock,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task structured data generation."""
# Mock the OpenAI response stream with JSON data
mock_create_stream.return_value = [
create_message_item(
id="msg_A", text='{"characters": ["Mario", "Luigi"]}', output_index=0
)
]
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id="ai_task.openai_ai_task",
instructions="Generate test data",
structure=vol.Schema(
{
vol.Required("characters"): selector.selector(
{
"text": {
"multiple": True,
}
}
)
},
),
)
assert result.data == {"characters": ["Mario", "Luigi"]}
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_invalid_structured_data(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_create_stream: AsyncMock,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task with invalid JSON response."""
# Mock the OpenAI response stream with invalid JSON
mock_create_stream.return_value = [
create_message_item(id="msg_A", text="INVALID JSON RESPONSE", output_index=0)
]
with pytest.raises(
HomeAssistantError, match="Error with OpenAI structured response"
):
await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id="ai_task.openai_ai_task",
instructions="Generate test data",
structure=vol.Schema(
{
vol.Required("characters"): selector.selector(
{
"text": {
"multiple": True,
}
}
)
},
),
)

View File

@ -8,7 +8,9 @@ from openai.types.responses import Response, ResponseOutputMessage, ResponseOutp
import pytest import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.openai_conversation.config_flow import RECOMMENDED_OPTIONS from homeassistant.components.openai_conversation.config_flow import (
RECOMMENDED_CONVERSATION_OPTIONS,
)
from homeassistant.components.openai_conversation.const import ( from homeassistant.components.openai_conversation.const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
@ -24,8 +26,10 @@ from homeassistant.components.openai_conversation.const import (
CONF_WEB_SEARCH_REGION, CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE, CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION, CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME, DEFAULT_CONVERSATION_NAME,
DOMAIN, DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS, RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TOP_P, RECOMMENDED_TOP_P,
@ -77,10 +81,16 @@ async def test_form(hass: HomeAssistant) -> None:
assert result2["subentries"] == [ assert result2["subentries"] == [
{ {
"subentry_type": "conversation", "subentry_type": "conversation",
"data": RECOMMENDED_OPTIONS, "data": RECOMMENDED_CONVERSATION_OPTIONS,
"title": DEFAULT_CONVERSATION_NAME, "title": DEFAULT_CONVERSATION_NAME,
"unique_id": None, "unique_id": None,
} },
{
"subentry_type": "ai_task_data",
"data": RECOMMENDED_AI_TASK_OPTIONS,
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
] ]
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
@ -131,14 +141,14 @@ async def test_creating_conversation_subentry(
result2 = await hass.config_entries.subentries.async_configure( result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"], result["flow_id"],
{"name": "My Custom Agent", **RECOMMENDED_OPTIONS}, {"name": "My Custom Agent", **RECOMMENDED_CONVERSATION_OPTIONS},
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result2["type"] is FlowResultType.CREATE_ENTRY assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "My Custom Agent" assert result2["title"] == "My Custom Agent"
processed_options = RECOMMENDED_OPTIONS.copy() processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip() processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
assert result2["data"] == processed_options assert result2["data"] == processed_options
@ -709,3 +719,110 @@ async def test_subentry_web_search_user_location(
CONF_WEB_SEARCH_COUNTRY: "US", CONF_WEB_SEARCH_COUNTRY: "US",
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles", CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
} }
async def test_creating_ai_task_subentry(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test creating an AI task subentry."""
old_subentries = set(mock_config_entry.subentries)
# Original conversation + original ai_task
assert len(mock_config_entry.subentries) == 2
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "ai_task_data"),
context={"source": config_entries.SOURCE_USER},
)
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "init"
assert not result.get("errors")
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{
"name": "Custom AI Task",
CONF_RECOMMENDED: True,
},
)
await hass.async_block_till_done()
assert result2.get("type") is FlowResultType.CREATE_ENTRY
assert result2.get("title") == "Custom AI Task"
assert result2.get("data") == {
CONF_RECOMMENDED: True,
}
assert (
len(mock_config_entry.subentries) == 3
) # Original conversation + original ai_task + new ai_task
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
new_subentry = mock_config_entry.subentries[new_subentry_id]
assert new_subentry.subentry_type == "ai_task_data"
assert new_subentry.title == "Custom AI Task"
async def test_ai_task_subentry_not_loaded(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating an AI task subentry when entry is not loaded."""
# Don't call mock_init_component to simulate not loaded state
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "ai_task_data"),
context={"source": config_entries.SOURCE_USER},
)
assert result.get("type") is FlowResultType.ABORT
assert result.get("reason") == "entry_not_loaded"
async def test_creating_ai_task_subentry_advanced(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test creating an AI task subentry with advanced settings."""
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "ai_task_data"),
context={"source": config_entries.SOURCE_USER},
)
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "init"
# Go to advanced settings
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{
"name": "Advanced AI Task",
CONF_RECOMMENDED: False,
},
)
assert result2.get("type") is FlowResultType.FORM
assert result2.get("step_id") == "advanced"
# Configure advanced settings
result3 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{
CONF_CHAT_MODEL: "gpt-4o",
CONF_MAX_TOKENS: 200,
CONF_TEMPERATURE: 0.5,
CONF_TOP_P: 0.9,
},
)
assert result3.get("type") is FlowResultType.CREATE_ENTRY
assert result3.get("title") == "Advanced AI Task"
assert result3.get("data") == {
CONF_RECOMMENDED: False,
CONF_CHAT_MODEL: "gpt-4o",
CONF_MAX_TOKENS: 200,
CONF_TEMPERATURE: 0.5,
CONF_TOP_P: 0.9,
}

View File

@ -1,41 +1,15 @@
"""Tests for the OpenAI integration.""" """Tests for the OpenAI integration."""
from collections.abc import Generator
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import httpx import httpx
from openai import AuthenticationError, RateLimitError from openai import AuthenticationError, RateLimitError
from openai.types import ResponseFormatText
from openai.types.responses import ( from openai.types.responses import (
Response,
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseError, ResponseError,
ResponseErrorEvent, ResponseErrorEvent,
ResponseFailedEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
ResponseIncompleteEvent,
ResponseInProgressEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
ResponseStreamEvent, ResponseStreamEvent,
ResponseTextConfig,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
) )
from openai.types.responses.response import IncompleteDetails from openai.types.responses.response import IncompleteDetails
from openai.types.responses.response_function_web_search import ActionSearch
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
@ -55,6 +29,13 @@ from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import (
create_function_tool_call_item,
create_message_item,
create_reasoning_item,
create_web_search_item,
)
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
from tests.components.conversation import ( from tests.components.conversation import (
MockChatLog, MockChatLog,
@ -62,97 +43,6 @@ from tests.components.conversation import (
) )
@pytest.fixture
def mock_create_stream() -> Generator[AsyncMock]:
"""Mock stream response."""
async def mock_generator(events, **kwargs):
response = Response(
id="resp_A",
created_at=1700000000,
error=None,
incomplete_details=None,
instructions=kwargs.get("instructions"),
metadata=kwargs.get("metadata", {}),
model=kwargs.get("model", "gpt-4o-mini"),
object="response",
output=[],
parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
temperature=kwargs.get("temperature", 1.0),
tool_choice=kwargs.get("tool_choice", "auto"),
tools=kwargs.get("tools"),
top_p=kwargs.get("top_p", 1.0),
max_output_tokens=kwargs.get("max_output_tokens", 100000),
previous_response_id=kwargs.get("previous_response_id"),
reasoning=kwargs.get("reasoning"),
status="in_progress",
text=kwargs.get(
"text", ResponseTextConfig(format=ResponseFormatText(type="text"))
),
truncation=kwargs.get("truncation", "disabled"),
usage=None,
user=kwargs.get("user"),
store=kwargs.get("store", True),
)
yield ResponseCreatedEvent(
response=response,
sequence_number=0,
type="response.created",
)
yield ResponseInProgressEvent(
response=response,
sequence_number=0,
type="response.in_progress",
)
response.status = "completed"
for value in events:
if isinstance(value, ResponseOutputItemDoneEvent):
response.output.append(value.item)
elif isinstance(value, IncompleteDetails):
response.status = "incomplete"
response.incomplete_details = value
break
if isinstance(value, ResponseError):
response.status = "failed"
response.error = value
break
yield value
if isinstance(value, ResponseErrorEvent):
return
if response.status == "incomplete":
yield ResponseIncompleteEvent(
response=response,
sequence_number=0,
type="response.incomplete",
)
elif response.status == "failed":
yield ResponseFailedEvent(
response=response,
sequence_number=0,
type="response.failed",
)
else:
yield ResponseCompletedEvent(
response=response,
sequence_number=0,
type="response.completed",
)
with patch(
"openai.resources.responses.AsyncResponses.create",
AsyncMock(),
) as mock_create:
mock_create.side_effect = lambda **kwargs: mock_generator(
mock_create.return_value.pop(0), **kwargs
)
yield mock_create
async def test_entity( async def test_entity(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
@ -347,225 +237,6 @@ async def test_conversation_agent(
assert agent.supported_languages == "*" assert agent.supported_languages == "*"
def create_message_item(
id: str, text: str | list[str], output_index: int
) -> list[ResponseStreamEvent]:
"""Create a message item."""
if isinstance(text, str):
text = [text]
content = ResponseOutputText(annotations=[], text="", type="output_text")
events = [
ResponseOutputItemAddedEvent(
item=ResponseOutputMessage(
id=id,
content=[],
type="message",
role="assistant",
status="in_progress",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseContentPartAddedEvent(
content_index=0,
item_id=id,
output_index=output_index,
part=content,
sequence_number=0,
type="response.content_part.added",
),
]
content.text = "".join(text)
events.extend(
ResponseTextDeltaEvent(
content_index=0,
delta=delta,
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.output_text.delta",
)
for delta in text
)
events.extend(
[
ResponseTextDoneEvent(
content_index=0,
item_id=id,
output_index=output_index,
text="".join(text),
sequence_number=0,
type="response.output_text.done",
),
ResponseContentPartDoneEvent(
content_index=0,
item_id=id,
output_index=output_index,
part=content,
sequence_number=0,
type="response.content_part.done",
),
ResponseOutputItemDoneEvent(
item=ResponseOutputMessage(
id=id,
content=[content],
role="assistant",
status="completed",
type="message",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]
)
return events
def create_function_tool_call_item(
id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
) -> list[ResponseStreamEvent]:
"""Create a function tool call item."""
if isinstance(arguments, str):
arguments = [arguments]
events = [
ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=id,
arguments="",
call_id=call_id,
name=name,
type="function_call",
status="in_progress",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
)
]
events.extend(
ResponseFunctionCallArgumentsDeltaEvent(
delta=delta,
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.function_call_arguments.delta",
)
for delta in arguments
)
events.append(
ResponseFunctionCallArgumentsDoneEvent(
arguments="".join(arguments),
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.function_call_arguments.done",
)
)
events.append(
ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=id,
arguments="".join(arguments),
call_id=call_id,
name=name,
type="function_call",
status="completed",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
)
)
return events
def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
"""Create a reasoning item."""
return [
ResponseOutputItemAddedEvent(
item=ResponseReasoningItem(
id=id,
summary=[],
type="reasoning",
status=None,
encrypted_content="AAA",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseOutputItemDoneEvent(
item=ResponseReasoningItem(
id=id,
summary=[],
type="reasoning",
status=None,
encrypted_content="AAABBB",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]
def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
"""Create a web search call item."""
return [
ResponseOutputItemAddedEvent(
item=ResponseFunctionWebSearch(
id=id,
status="in_progress",
action=ActionSearch(query="query", type="search"),
type="web_search_call",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseWebSearchCallInProgressEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.in_progress",
),
ResponseWebSearchCallSearchingEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.searching",
),
ResponseWebSearchCallCompletedEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.completed",
),
ResponseOutputItemDoneEvent(
item=ResponseFunctionWebSearch(
id=id,
status="completed",
action=ActionSearch(query="query", type="search"),
type="web_search_call",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]
async def test_function_call( async def test_function_call(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry_with_reasoning_model: MockConfigEntry, mock_config_entry_with_reasoning_model: MockConfigEntry,

View File

@ -0,0 +1,77 @@
"""Tests for the OpenAI Conversation entity."""
import voluptuous as vol
from homeassistant.components.openai_conversation.entity import (
_format_structured_output,
)
from homeassistant.helpers import selector
async def test_format_structured_output() -> None:
"""Test the format_structured_output function."""
schema = vol.Schema(
{
vol.Required("name"): selector.TextSelector(),
vol.Optional("age"): selector.NumberSelector(
config=selector.NumberSelectorConfig(
min=0,
max=120,
),
),
vol.Required("stuff"): selector.ObjectSelector(
{
"multiple": True,
"fields": {
"item_name": {
"selector": {"text": None},
},
"item_value": {
"selector": {"text": None},
},
},
}
),
}
)
assert _format_structured_output(schema, None) == {
"additionalProperties": False,
"properties": {
"age": {
"maximum": 120.0,
"minimum": 0.0,
"type": [
"number",
"null",
],
},
"name": {
"type": "string",
},
"stuff": {
"items": {
"properties": {
"item_name": {
"type": ["string", "null"],
},
"item_value": {
"type": ["string", "null"],
},
},
"required": [
"item_name",
"item_value",
],
"type": "object",
},
"type": "array",
},
},
"required": [
"name",
"stuff",
"age",
],
"strict": True,
"type": "object",
}

View File

@ -17,7 +17,10 @@ from syrupy.assertion import SnapshotAssertion
from syrupy.filters import props from syrupy.filters import props
from homeassistant.components.openai_conversation import CONF_CHAT_MODEL from homeassistant.components.openai_conversation import CONF_CHAT_MODEL
from homeassistant.components.openai_conversation.const import DOMAIN from homeassistant.components.openai_conversation.const import (
DEFAULT_AI_TASK_NAME,
DOMAIN,
)
from homeassistant.config_entries import ConfigSubentryData from homeassistant.config_entries import ConfigSubentryData
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
@ -534,7 +537,7 @@ async def test_generate_content_service_error(
) )
async def test_migration_from_v1_to_v2( async def test_migration_from_v1(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
@ -582,17 +585,33 @@ async def test_migration_from_v1_to_v2(
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_config_entry.version == 2 assert mock_config_entry.version == 2
assert mock_config_entry.minor_version == 2 assert mock_config_entry.minor_version == 3
assert mock_config_entry.data == {"api_key": "1234"} assert mock_config_entry.data == {"api_key": "1234"}
assert mock_config_entry.options == {} assert mock_config_entry.options == {}
assert len(mock_config_entry.subentries) == 1 assert len(mock_config_entry.subentries) == 2
subentry = next(iter(mock_config_entry.subentries.values())) # Find the conversation subentry
assert subentry.unique_id is None conversation_subentry = None
assert subentry.title == "ChatGPT" ai_task_subentry = None
assert subentry.subentry_type == "conversation" for subentry in mock_config_entry.subentries.values():
assert subentry.data == OPTIONS if subentry.subentry_type == "conversation":
conversation_subentry = subentry
elif subentry.subentry_type == "ai_task_data":
ai_task_subentry = subentry
assert conversation_subentry is not None
assert conversation_subentry.unique_id is None
assert conversation_subentry.title == "ChatGPT"
assert conversation_subentry.subentry_type == "conversation"
assert conversation_subentry.data == OPTIONS
assert ai_task_subentry is not None
assert ai_task_subentry.unique_id is None
assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
assert ai_task_subentry.subentry_type == "ai_task_data"
# Use conversation subentry for the rest of the assertions
subentry = conversation_subentry
migrated_entity = entity_registry.async_get(entity.entity_id) migrated_entity = entity_registry.async_get(entity.entity_id)
assert migrated_entity is not None assert migrated_entity is not None
@ -617,12 +636,12 @@ async def test_migration_from_v1_to_v2(
} }
async def test_migration_from_v1_to_v2_with_multiple_keys( async def test_migration_from_v1_with_multiple_keys(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
) -> None: ) -> None:
"""Test migration from version 1 to version 2 with different API keys.""" """Test migration from version 1 with different API keys."""
# Create two v1 config entries with different API keys # Create two v1 config entries with different API keys
options = { options = {
"recommended": True, "recommended": True,
@ -695,28 +714,38 @@ async def test_migration_from_v1_to_v2_with_multiple_keys(
for idx, entry in enumerate(entries): for idx, entry in enumerate(entries):
assert entry.version == 2 assert entry.version == 2
assert entry.minor_version == 2 assert entry.minor_version == 3
assert not entry.options assert not entry.options
assert len(entry.subentries) == 1 assert len(entry.subentries) == 2
subentry = list(entry.subentries.values())[0]
assert subentry.subentry_type == "conversation" conversation_subentry = None
assert subentry.data == options for subentry in entry.subentries.values():
assert subentry.title == f"ChatGPT {idx + 1}" if subentry.subentry_type == "conversation":
conversation_subentry = subentry
break
assert conversation_subentry is not None
assert conversation_subentry.subentry_type == "conversation"
assert conversation_subentry.data == options
assert conversation_subentry.title == f"ChatGPT {idx + 1}"
# Use conversation subentry for device assertions
subentry = conversation_subentry
dev = device_registry.async_get_device( dev = device_registry.async_get_device(
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)} identifiers={(DOMAIN, subentry.subentry_id)}
) )
assert dev is not None assert dev is not None
assert dev.config_entries == {entry.entry_id} assert dev.config_entries == {entry.entry_id}
assert dev.config_entries_subentries == {entry.entry_id: {subentry.subentry_id}} assert dev.config_entries_subentries == {entry.entry_id: {subentry.subentry_id}}
async def test_migration_from_v1_to_v2_with_same_keys( async def test_migration_from_v1_with_same_keys(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
) -> None: ) -> None:
"""Test migration from version 1 to version 2 with same API keys consolidates entries.""" """Test migration from version 1 with same API keys consolidates entries."""
# Create two v1 config entries with the same API key # Create two v1 config entries with the same API key
options = { options = {
"recommended": True, "recommended": True,
@ -790,17 +819,28 @@ async def test_migration_from_v1_to_v2_with_same_keys(
entry = entries[0] entry = entries[0]
assert entry.version == 2 assert entry.version == 2
assert entry.minor_version == 2 assert entry.minor_version == 3
assert not entry.options assert not entry.options
assert len(entry.subentries) == 2 # Two subentries from the two original entries assert (
len(entry.subentries) == 3
) # Two conversation subentries + one AI task subentry
# Check both subentries exist with correct data # Check both conversation subentries exist with correct data
subentries = list(entry.subentries.values()) conversation_subentries = [
titles = [sub.title for sub in subentries] sub for sub in entry.subentries.values() if sub.subentry_type == "conversation"
]
ai_task_subentries = [
sub for sub in entry.subentries.values() if sub.subentry_type == "ai_task_data"
]
assert len(conversation_subentries) == 2
assert len(ai_task_subentries) == 1
titles = [sub.title for sub in conversation_subentries]
assert "ChatGPT" in titles assert "ChatGPT" in titles
assert "ChatGPT 2" in titles assert "ChatGPT 2" in titles
for subentry in subentries: for subentry in conversation_subentries:
assert subentry.subentry_type == "conversation" assert subentry.subentry_type == "conversation"
assert subentry.data == options assert subentry.data == options
@ -815,12 +855,12 @@ async def test_migration_from_v1_to_v2_with_same_keys(
} }
async def test_migration_from_v2_1_to_v2_2( async def test_migration_from_v2_1(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
) -> None: ) -> None:
"""Test migration from version 2.1 to version 2.2. """Test migration from version 2.1.
This tests we clean up the broken migration in Home Assistant Core This tests we clean up the broken migration in Home Assistant Core
2025.7.0b0-2025.7.0b1: 2025.7.0b0-2025.7.0b1:
@ -913,16 +953,22 @@ async def test_migration_from_v2_1_to_v2_2(
assert len(entries) == 1 assert len(entries) == 1
entry = entries[0] entry = entries[0]
assert entry.version == 2 assert entry.version == 2
assert entry.minor_version == 2 assert entry.minor_version == 3
assert not entry.options assert not entry.options
assert entry.title == "ChatGPT" assert entry.title == "ChatGPT"
assert len(entry.subentries) == 2 assert len(entry.subentries) == 3 # 2 conversation + 1 AI task
conversation_subentries = [ conversation_subentries = [
subentry subentry
for subentry in entry.subentries.values() for subentry in entry.subentries.values()
if subentry.subentry_type == "conversation" if subentry.subentry_type == "conversation"
] ]
ai_task_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "ai_task_data"
]
assert len(conversation_subentries) == 2 assert len(conversation_subentries) == 2
assert len(ai_task_subentries) == 1
for subentry in conversation_subentries: for subentry in conversation_subentries:
assert subentry.subentry_type == "conversation" assert subentry.subentry_type == "conversation"
assert subentry.data == options assert subentry.data == options
@ -972,7 +1018,9 @@ async def test_migration_from_v2_1_to_v2_2(
} }
@pytest.mark.parametrize("mock_subentry_data", [{}, {CONF_CHAT_MODEL: "gpt-1o"}]) @pytest.mark.parametrize(
"mock_conversation_subentry_data", [{}, {CONF_CHAT_MODEL: "gpt-1o"}]
)
async def test_devices( async def test_devices(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
@ -980,12 +1028,89 @@ async def test_devices(
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Assert exception when invalid config entry is provided.""" """Test devices are correctly created for subentries."""
devices = dr.async_entries_for_config_entry( devices = dr.async_entries_for_config_entry(
device_registry, mock_config_entry.entry_id device_registry, mock_config_entry.entry_id
) )
assert len(devices) == 1 assert len(devices) == 2 # One for conversation, one for AI task
# Use the first device for snapshot comparison
device = devices[0] device = devices[0]
assert device == snapshot(exclude=props("identifiers")) assert device == snapshot(exclude=props("identifiers"))
subentry = next(iter(mock_config_entry.subentries.values())) # Verify the device has identifiers matching one of the subentries
assert device.identifiers == {(DOMAIN, subentry.subentry_id)} expected_identifiers = [
{(DOMAIN, subentry.subentry_id)}
for subentry in mock_config_entry.subentries.values()
]
assert device.identifiers in expected_identifiers
async def test_migration_from_v2_2(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 2.2."""
# Create a v2.2 config entry with a conversation subentry
options = {
"recommended": True,
"llm_hass_api": ["assist"],
"prompt": "You are a helpful assistant",
"chat_model": "gpt-4o-mini",
}
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={"api_key": "1234"},
entry_id="mock_entry_id",
version=2,
minor_version=2,
subentries_data=[
ConfigSubentryData(
data=options,
subentry_id="mock_id_1",
subentry_type="conversation",
title="ChatGPT",
unique_id=None,
),
],
title="ChatGPT",
)
mock_config_entry.add_to_hass(hass)
# Run migration
with patch(
"homeassistant.components.openai_conversation.async_setup_entry",
return_value=True,
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.version == 2
assert entry.minor_version == 3
assert not entry.options
assert entry.title == "ChatGPT"
assert len(entry.subentries) == 2
# Check conversation subentry is still there
conversation_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "conversation"
]
assert len(conversation_subentries) == 1
conversation_subentry = conversation_subentries[0]
assert conversation_subentry.data == options
# Check AI Task subentry was added
ai_task_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "ai_task_data"
]
assert len(ai_task_subentries) == 1
ai_task_subentry = ai_task_subentries[0]
assert ai_task_subentry.data == {"recommended": True}
assert ai_task_subentry.title == "OpenAI AI Task"