Add OpenAI AI Task entity (#148295)

Paulus Schoutsen 2025-07-10 23:08:56 +02:00 committed by GitHub
parent f0a636949a
commit 0e09a47476
14 changed files with 1152 additions and 463 deletions

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from pathlib import Path
from types import MappingProxyType
import openai
from openai.types.images_response import ImagesResponse
@@ -45,9 +46,11 @@ from .const import (
CONF_REASONING_EFFORT,
CONF_TEMPERATURE,
CONF_TOP_P,
DEFAULT_AI_TASK_NAME,
DEFAULT_NAME,
DOMAIN,
LOGGER,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
@@ -59,7 +62,7 @@ from .entity import async_prepare_files_for_prompt
SERVICE_GENERATE_IMAGE = "generate_image"
SERVICE_GENERATE_CONTENT = "generate_content"
PLATFORMS = (Platform.CONVERSATION,)
PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
@@ -153,7 +156,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
EasyInputMessageParam(type="message", role="user", content=content)
]
try:
model_args = {
"model": model,
"input": messages,
@@ -175,6 +177,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
)
}
try:
response: Response = await client.responses.create(**model_args)
except openai.OpenAIError as err:
@@ -361,6 +364,18 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) ->
hass.config_entries.async_update_entry(entry, minor_version=2)
if entry.version == 2 and entry.minor_version == 2:
hass.config_entries.async_add_subentry(
entry,
ConfigSubentry(
data=MappingProxyType(RECOMMENDED_AI_TASK_OPTIONS),
subentry_type="ai_task_data",
title=DEFAULT_AI_TASK_NAME,
unique_id=None,
),
)
hass.config_entries.async_update_entry(entry, minor_version=3)
LOGGER.debug(
"Migration to version %s:%s successful", entry.version, entry.minor_version
)
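For reference, a minimal sketch (shape only, not part of the diff) of the subentry this 2.2 -> 2.3 migration step adds; the values mirror RECOMMENDED_AI_TASK_OPTIONS and DEFAULT_AI_TASK_NAME from const.py below, and the migration tests later in this commit assert exactly this shape:

# Sketch: the subentry produced by the migration above.
expected_ai_task_subentry = {
    "subentry_type": "ai_task_data",
    "data": {"recommended": True},  # RECOMMENDED_AI_TASK_OPTIONS
    "title": "OpenAI AI Task",  # DEFAULT_AI_TASK_NAME
    "unique_id": None,
}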

View File

@@ -0,0 +1,77 @@
"""AI Task integration for OpenAI."""
from __future__ import annotations
from json import JSONDecodeError
import logging
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
from .entity import OpenAIBaseLLMEntity
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up AI Task entities."""
for subentry in config_entry.subentries.values():
if subentry.subentry_type != "ai_task_data":
continue
async_add_entities(
[OpenAITaskEntity(config_entry, subentry)],
config_subentry_id=subentry.subentry_id,
)
class OpenAITaskEntity(
ai_task.AITaskEntity,
OpenAIBaseLLMEntity,
):
"""OpenAI AI Task entity."""
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
async def _async_generate_data(
self,
task: ai_task.GenDataTask,
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Handle a generate data task."""
await self._async_handle_chat_log(chat_log, task.name, task.structure)
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
raise HomeAssistantError(
"Last content in chat log is not an AssistantContent"
)
text = chat_log.content[-1].content or ""
if not task.structure:
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=text,
)
try:
data = json_loads(text)
except JSONDecodeError as err:
_LOGGER.error(
"Failed to parse JSON response: %s. Response: %s",
err,
text,
)
raise HomeAssistantError("Error with OpenAI structured response") from err
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=data,
)
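A minimal usage sketch (not part of the diff; the entity_id and schema mirror the tests added later in this commit) of how the new entity is driven through the generic AI Task service:

# Sketch: `structure` is optional; without it the raw text is returned.
import voluptuous as vol

from homeassistant.components import ai_task
from homeassistant.helpers import selector


async def example(hass) -> None:
    """Illustrative call against the new OpenAI AI Task entity."""
    result = await ai_task.async_generate_data(
        hass,
        task_name="Test Task",
        entity_id="ai_task.openai_ai_task",
        instructions="Generate test data",
        structure=vol.Schema(
            {
                vol.Required("characters"): selector.selector(
                    {"text": {"multiple": True}}
                )
            }
        ),
    )
    # With a structure, result.data is the parsed JSON; otherwise plain text.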

View File

@@ -55,9 +55,12 @@ from .const import (
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE,
@@ -77,12 +80,6 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
}
)
RECOMMENDED_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
}
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect.
@@ -99,7 +96,7 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for OpenAI Conversation."""
VERSION = 2
MINOR_VERSION = 2
MINOR_VERSION = 3
async def async_step_user(
self, user_input: dict[str, Any] | None = None
@@ -129,10 +126,16 @@
subentries=[
{
"subentry_type": "conversation",
"data": RECOMMENDED_OPTIONS,
"data": RECOMMENDED_CONVERSATION_OPTIONS,
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
},
{
"subentry_type": "ai_task_data",
"data": RECOMMENDED_AI_TASK_OPTIONS,
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
],
)
@@ -146,11 +149,14 @@
cls, config_entry: ConfigEntry
) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this integration."""
return {"conversation": ConversationSubentryFlowHandler}
return {
"conversation": OpenAISubentryFlowHandler,
"ai_task_data": OpenAISubentryFlowHandler,
}
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
"""Flow for managing conversation subentries."""
class OpenAISubentryFlowHandler(ConfigSubentryFlow):
"""Flow for managing OpenAI subentries."""
last_rendered_recommended = False
options: dict[str, Any]
@@ -164,7 +170,10 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Add a subentry."""
self.options = RECOMMENDED_OPTIONS.copy()
if self._subentry_type == "ai_task_data":
self.options = RECOMMENDED_AI_TASK_OPTIONS.copy()
else:
self.options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
return await self.async_step_init()
async def async_step_reconfigure(
@@ -181,6 +190,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
# abort if entry is not loaded
if self._get_entry().state != ConfigEntryState.LOADED:
return self.async_abort(reason="entry_not_loaded")
options = self.options
hass_apis: list[SelectOptionDict] = [
@@ -198,10 +208,13 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
step_schema: VolDictType = {}
if self._is_new:
step_schema[vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME)] = (
str
)
if self._subentry_type == "ai_task_data":
default_name = DEFAULT_AI_TASK_NAME
else:
default_name = DEFAULT_CONVERSATION_NAME
step_schema[vol.Required(CONF_NAME, default=default_name)] = str
if self._subentry_type == "conversation":
step_schema.update(
{
vol.Optional(
@@ -215,12 +228,13 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
vol.Optional(CONF_LLM_HASS_API): SelectSelector(
SelectSelectorConfig(options=hass_apis, multiple=True)
),
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
}
)
step_schema[
vol.Required(CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False))
] = bool
if user_input is not None:
if not user_input.get(CONF_LLM_HASS_API):
user_input.pop(CONF_LLM_HASS_API, None)
@@ -320,7 +334,9 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
elif CONF_REASONING_EFFORT in options:
options.pop(CONF_REASONING_EFFORT)
if not model.startswith(tuple(UNSUPPORTED_WEB_SEARCH_MODELS)):
if self._subentry_type == "conversation" and not model.startswith(
tuple(UNSUPPORTED_WEB_SEARCH_MODELS)
):
step_schema.update(
{
vol.Optional(
@@ -362,7 +378,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
if not step_schema:
if self._is_new:
return self.async_create_entry(
title=options.pop(CONF_NAME, DEFAULT_CONVERSATION_NAME),
title=options.pop(CONF_NAME),
data=options,
)
return self.async_update_and_abort(
@@ -384,7 +400,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
options.update(user_input)
if self._is_new:
return self.async_create_entry(
title=options.pop(CONF_NAME, DEFAULT_CONVERSATION_NAME),
title=options.pop(CONF_NAME),
data=options,
)
return self.async_update_and_abort(

View File

@@ -2,10 +2,14 @@
import logging
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.helpers import llm
DOMAIN = "openai_conversation"
LOGGER: logging.Logger = logging.getLogger(__package__)
DEFAULT_CONVERSATION_NAME = "OpenAI Conversation"
DEFAULT_AI_TASK_NAME = "OpenAI AI Task"
DEFAULT_NAME = "OpenAI Conversation"
CONF_CHAT_MODEL = "chat_model"
@@ -51,3 +55,12 @@ UNSUPPORTED_WEB_SEARCH_MODELS: list[str] = [
"o1",
"o3-mini",
]
RECOMMENDED_CONVERSATION_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
}
RECOMMENDED_AI_TASK_OPTIONS = {
CONF_RECOMMENDED: True,
}

View File

@@ -39,6 +39,7 @@ from openai.types.responses import (
)
from openai.types.responses.response_input_param import FunctionCallOutput
from openai.types.responses.web_search_tool_param import UserLocation
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import conversation
@@ -47,6 +48,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, llm
from homeassistant.helpers.entity import Entity
from homeassistant.util import slugify
from .const import (
CONF_CHAT_MODEL,
@@ -79,6 +81,47 @@ if TYPE_CHECKING:
MAX_TOOL_ITERATIONS = 10
def _adjust_schema(schema: dict[str, Any]) -> None:
"""Adjust the schema to be compatible with OpenAI API."""
if schema["type"] == "object":
if "properties" not in schema:
return
if "required" not in schema:
schema["required"] = []
# Ensure all properties are required
for prop, prop_info in schema["properties"].items():
_adjust_schema(prop_info)
if prop not in schema["required"]:
prop_info["type"] = [prop_info["type"], "null"]
schema["required"].append(prop)
elif schema["type"] == "array":
if "items" not in schema:
return
_adjust_schema(schema["items"])
def _format_structured_output(
schema: vol.Schema, llm_api: llm.APIInstance | None
) -> dict[str, Any]:
"""Format the schema to be compatible with OpenAI API."""
result: dict[str, Any] = convert(
schema,
custom_serializer=(
llm_api.custom_serializer if llm_api else llm.selector_serializer
),
)
_adjust_schema(result)
result["strict"] = True
result["additionalProperties"] = False
return result
def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> FunctionToolParam:
@@ -243,6 +286,8 @@ class OpenAIBaseLLMEntity(Entity):
async def _async_handle_chat_log(
self,
chat_log: conversation.ChatLog,
structure_name: str | None = None,
structure: vol.Schema | None = None,
) -> None:
"""Generate an answer for the chat log."""
options = self.subentry.data
@@ -273,23 +318,10 @@
tools = []
tools.append(web_search)
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
messages = [
m
for content in chat_log.content
for m in _convert_content_to_param(content)
]
client = self.entry.runtime_data
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
model_args = {
"model": model,
"input": messages,
"max_output_tokens": options.get(
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
),
"model": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
"input": [],
"max_output_tokens": options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
"user": chat_log.conversation_id,
@@ -299,13 +331,34 @@
if tools:
model_args["tools"] = tools
if model.startswith("o"):
if model_args["model"].startswith("o"):
model_args["reasoning"] = {
"effort": options.get(
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
)
}
model_args["include"] = ["reasoning.encrypted_content"]
else:
model_args["store"] = False
messages = [
m
for content in chat_log.content
for m in _convert_content_to_param(content)
]
if structure and structure_name:
model_args["text"] = {
"format": {
"type": "json_schema",
"name": slugify(structure_name),
"schema": _format_structured_output(structure, chat_log.llm_api),
},
}
client = self.entry.runtime_data
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
model_args["input"] = messages
try:
result = await client.responses.create(**model_args)
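To make the schema adjustment concrete, a worked example (input and expected output taken from the entity tests added later in this commit): _adjust_schema moves every property into "required" and widens formerly optional ones to nullable types, because the OpenAI structured-output API in strict mode rejects optional properties.

import voluptuous as vol

from homeassistant.helpers import selector

example_schema = vol.Schema(
    {
        vol.Required("name"): selector.TextSelector(),
        vol.Optional("age"): selector.NumberSelector(
            config=selector.NumberSelectorConfig(min=0, max=120)
        ),
    }
)
# _format_structured_output(example_schema, None) returns:
# {
#     "type": "object",
#     "properties": {
#         "name": {"type": "string"},
#         "age": {"type": ["number", "null"], "minimum": 0.0, "maximum": 120.0},
#     },
#     "required": ["name", "age"],
#     "additionalProperties": False,
#     "strict": True,
# }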

View File

@@ -68,6 +68,52 @@
"error": {
"model_not_supported": "This model is not supported, please select a different model"
}
},
"ai_task_data": {
"initiate_flow": {
"user": "Add Generate data with AI service",
"reconfigure": "Reconfigure Generate data with AI service"
},
"entry_type": "Generate data with AI service",
"step": {
"init": {
"data": {
"name": "[%key:common::config_flow::data::name%]",
"recommended": "[%key:component::openai_conversation::config_subentries::conversation::step::init::data::recommended%]"
}
},
"advanced": {
"title": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::title%]",
"data": {
"chat_model": "[%key:common::generic::model%]",
"max_tokens": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::max_tokens%]",
"temperature": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::temperature%]",
"top_p": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::top_p%]"
}
},
"model": {
"title": "[%key:component::openai_conversation::config_subentries::conversation::step::model::title%]",
"data": {
"reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::reasoning_effort%]",
"web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::web_search%]",
"search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::search_context_size%]",
"user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::user_location%]"
},
"data_description": {
"reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::reasoning_effort%]",
"web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::web_search%]",
"search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::search_context_size%]",
"user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::user_location%]"
}
}
},
"abort": {
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
"entry_not_loaded": "[%key:component::openai_conversation::config_subentries::conversation::abort::entry_not_loaded%]"
},
"error": {
"model_not_supported": "[%key:component::openai_conversation::config_subentries::conversation::error::model_not_supported%]"
}
}
},
"selector": {

View File

@@ -1 +1,241 @@
"""Tests for the OpenAI Conversation integration."""
from openai.types.responses import (
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
ResponseStreamEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
)
from openai.types.responses.response_function_web_search import ActionSearch
def create_message_item(
id: str, text: str | list[str], output_index: int
) -> list[ResponseStreamEvent]:
"""Create a message item."""
if isinstance(text, str):
text = [text]
content = ResponseOutputText(annotations=[], text="", type="output_text")
events = [
ResponseOutputItemAddedEvent(
item=ResponseOutputMessage(
id=id,
content=[],
type="message",
role="assistant",
status="in_progress",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseContentPartAddedEvent(
content_index=0,
item_id=id,
output_index=output_index,
part=content,
sequence_number=0,
type="response.content_part.added",
),
]
content.text = "".join(text)
events.extend(
ResponseTextDeltaEvent(
content_index=0,
delta=delta,
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.output_text.delta",
)
for delta in text
)
events.extend(
[
ResponseTextDoneEvent(
content_index=0,
item_id=id,
output_index=output_index,
text="".join(text),
sequence_number=0,
type="response.output_text.done",
),
ResponseContentPartDoneEvent(
content_index=0,
item_id=id,
output_index=output_index,
part=content,
sequence_number=0,
type="response.content_part.done",
),
ResponseOutputItemDoneEvent(
item=ResponseOutputMessage(
id=id,
content=[content],
role="assistant",
status="completed",
type="message",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]
)
return events
def create_function_tool_call_item(
id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
) -> list[ResponseStreamEvent]:
"""Create a function tool call item."""
if isinstance(arguments, str):
arguments = [arguments]
events = [
ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=id,
arguments="",
call_id=call_id,
name=name,
type="function_call",
status="in_progress",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
)
]
events.extend(
ResponseFunctionCallArgumentsDeltaEvent(
delta=delta,
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.function_call_arguments.delta",
)
for delta in arguments
)
events.append(
ResponseFunctionCallArgumentsDoneEvent(
arguments="".join(arguments),
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.function_call_arguments.done",
)
)
events.append(
ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=id,
arguments="".join(arguments),
call_id=call_id,
name=name,
type="function_call",
status="completed",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
)
)
return events
def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
"""Create a reasoning item."""
return [
ResponseOutputItemAddedEvent(
item=ResponseReasoningItem(
id=id,
summary=[],
type="reasoning",
status=None,
encrypted_content="AAA",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseOutputItemDoneEvent(
item=ResponseReasoningItem(
id=id,
summary=[],
type="reasoning",
status=None,
encrypted_content="AAABBB",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]
def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
"""Create a web search call item."""
return [
ResponseOutputItemAddedEvent(
item=ResponseFunctionWebSearch(
id=id,
status="in_progress",
action=ActionSearch(query="query", type="search"),
type="web_search_call",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseWebSearchCallInProgressEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.in_progress",
),
ResponseWebSearchCallSearchingEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.searching",
),
ResponseWebSearchCallCompletedEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.completed",
),
ResponseOutputItemDoneEvent(
item=ResponseFunctionWebSearch(
id=id,
status="completed",
action=ActionSearch(query="query", type="search"),
type="web_search_call",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]

View File

@@ -1,13 +1,30 @@
"""Tests helpers."""
from collections.abc import Generator
from typing import Any
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
from openai.types import ResponseFormatText
from openai.types.responses import (
Response,
ResponseCompletedEvent,
ResponseCreatedEvent,
ResponseError,
ResponseErrorEvent,
ResponseFailedEvent,
ResponseIncompleteEvent,
ResponseInProgressEvent,
ResponseOutputItemDoneEvent,
ResponseTextConfig,
)
from openai.types.responses.response import IncompleteDetails
import pytest
from homeassistant.components.openai_conversation.const import (
CONF_CHAT_MODEL,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
RECOMMENDED_AI_TASK_OPTIONS,
)
from homeassistant.config_entries import ConfigSubentryData
from homeassistant.const import CONF_LLM_HASS_API
@@ -19,14 +36,14 @@ from tests.common import MockConfigEntry
@pytest.fixture
def mock_subentry_data() -> dict[str, Any]:
def mock_conversation_subentry_data() -> dict[str, Any]:
"""Mock subentry data."""
return {}
@pytest.fixture
def mock_config_entry(
hass: HomeAssistant, mock_subentry_data: dict[str, Any]
hass: HomeAssistant, mock_conversation_subentry_data: dict[str, Any]
) -> MockConfigEntry:
"""Mock a config entry."""
entry = MockConfigEntry(
@@ -36,13 +53,20 @@ def mock_config_entry(
"api_key": "bla",
},
version=2,
minor_version=3,
subentries_data=[
ConfigSubentryData(
data=mock_subentry_data,
data=mock_conversation_subentry_data,
subentry_type="conversation",
title=DEFAULT_CONVERSATION_NAME,
unique_id=None,
)
),
ConfigSubentryData(
data=RECOMMENDED_AI_TASK_OPTIONS,
subentry_type="ai_task_data",
title=DEFAULT_AI_TASK_NAME,
unique_id=None,
),
],
)
entry.add_to_hass(hass)
@@ -91,3 +115,94 @@ async def mock_init_component(
async def setup_ha(hass: HomeAssistant) -> None:
"""Set up Home Assistant."""
assert await async_setup_component(hass, "homeassistant", {})
@pytest.fixture
def mock_create_stream() -> Generator[AsyncMock]:
"""Mock stream response."""
async def mock_generator(events, **kwargs):
response = Response(
id="resp_A",
created_at=1700000000,
error=None,
incomplete_details=None,
instructions=kwargs.get("instructions"),
metadata=kwargs.get("metadata", {}),
model=kwargs.get("model", "gpt-4o-mini"),
object="response",
output=[],
parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
temperature=kwargs.get("temperature", 1.0),
tool_choice=kwargs.get("tool_choice", "auto"),
tools=kwargs.get("tools", []),
top_p=kwargs.get("top_p", 1.0),
max_output_tokens=kwargs.get("max_output_tokens", 100000),
previous_response_id=kwargs.get("previous_response_id"),
reasoning=kwargs.get("reasoning"),
status="in_progress",
text=kwargs.get(
"text", ResponseTextConfig(format=ResponseFormatText(type="text"))
),
truncation=kwargs.get("truncation", "disabled"),
usage=None,
user=kwargs.get("user"),
store=kwargs.get("store", True),
)
yield ResponseCreatedEvent(
response=response,
sequence_number=0,
type="response.created",
)
yield ResponseInProgressEvent(
response=response,
sequence_number=0,
type="response.in_progress",
)
response.status = "completed"
for value in events:
if isinstance(value, ResponseOutputItemDoneEvent):
response.output.append(value.item)
elif isinstance(value, IncompleteDetails):
response.status = "incomplete"
response.incomplete_details = value
break
if isinstance(value, ResponseError):
response.status = "failed"
response.error = value
break
yield value
if isinstance(value, ResponseErrorEvent):
return
if response.status == "incomplete":
yield ResponseIncompleteEvent(
response=response,
sequence_number=0,
type="response.incomplete",
)
elif response.status == "failed":
yield ResponseFailedEvent(
response=response,
sequence_number=0,
type="response.failed",
)
else:
yield ResponseCompletedEvent(
response=response,
sequence_number=0,
type="response.completed",
)
with patch(
"openai.resources.responses.AsyncResponses.create",
AsyncMock(),
) as mock_create:
mock_create.side_effect = lambda **kwargs: mock_generator(
mock_create.return_value.pop(0), **kwargs
)
yield mock_create
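A usage sketch (names and values illustrative, following the test patterns in this commit): return_value queues one list of stream events per expected responses.create() call, so a tool-call turn followed by the final answer needs two entries.

# Sketch of a test body combining the fixture with the event factories
# from tests/components/openai_conversation/__init__.py above.
mock_create_stream.return_value = [
    # First API call streams a function tool call...
    create_function_tool_call_item(
        id="fc_1",
        arguments='{"param": "value"}',
        call_id="call_1",
        name="test_tool",
        output_index=0,
    ),
    # ...second call (after the tool result is fed back) streams the answer.
    create_message_item(id="msg_1", text="Done", output_index=0),
]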

View File

@@ -1,5 +1,5 @@
# serializer version: 1
# name: test_devices[mock_subentry_data0]
# name: test_devices[mock_conversation_subentry_data0]
DeviceRegistryEntrySnapshot({
'area_id': None,
'config_entries': <ANY>,
@@ -26,7 +26,7 @@
'via_device_id': None,
})
# ---
# name: test_devices[mock_subentry_data1]
# name: test_devices[mock_conversation_subentry_data1]
DeviceRegistryEntrySnapshot({
'area_id': None,
'config_entries': <ANY>,

View File

@@ -0,0 +1,124 @@
"""Test AI Task platform of OpenAI Conversation integration."""
from unittest.mock import AsyncMock
import pytest
import voluptuous as vol
from homeassistant.components import ai_task
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er, selector
from . import create_message_item
from tests.common import MockConfigEntry
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_data(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_create_stream: AsyncMock,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task data generation."""
entity_id = "ai_task.openai_ai_task"
# Ensure entity is linked to the subentry
entity_entry = entity_registry.async_get(entity_id)
ai_task_entry = next(
iter(
entry
for entry in mock_config_entry.subentries.values()
if entry.subentry_type == "ai_task_data"
)
)
assert entity_entry is not None
assert entity_entry.config_entry_id == mock_config_entry.entry_id
assert entity_entry.config_subentry_id == ai_task_entry.subentry_id
# Mock the OpenAI response stream
mock_create_stream.return_value = [
create_message_item(id="msg_A", text="The test data", output_index=0)
]
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Generate test data",
)
assert result.data == "The test data"
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_structured_data(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_create_stream: AsyncMock,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task structured data generation."""
# Mock the OpenAI response stream with JSON data
mock_create_stream.return_value = [
create_message_item(
id="msg_A", text='{"characters": ["Mario", "Luigi"]}', output_index=0
)
]
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id="ai_task.openai_ai_task",
instructions="Generate test data",
structure=vol.Schema(
{
vol.Required("characters"): selector.selector(
{
"text": {
"multiple": True,
}
}
)
},
),
)
assert result.data == {"characters": ["Mario", "Luigi"]}
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_invalid_structured_data(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_create_stream: AsyncMock,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task with invalid JSON response."""
# Mock the OpenAI response stream with invalid JSON
mock_create_stream.return_value = [
create_message_item(id="msg_A", text="INVALID JSON RESPONSE", output_index=0)
]
with pytest.raises(
HomeAssistantError, match="Error with OpenAI structured response"
):
await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id="ai_task.openai_ai_task",
instructions="Generate test data",
structure=vol.Schema(
{
vol.Required("characters"): selector.selector(
{
"text": {
"multiple": True,
}
}
)
},
),
)

View File

@@ -8,7 +8,9 @@ from openai.types.responses import Response, ResponseOutputMessage, ResponseOutp
import pytest
from homeassistant import config_entries
from homeassistant.components.openai_conversation.config_flow import RECOMMENDED_OPTIONS
from homeassistant.components.openai_conversation.config_flow import (
RECOMMENDED_CONVERSATION_OPTIONS,
)
from homeassistant.components.openai_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
@@ -24,8 +26,10 @@ from homeassistant.components.openai_conversation.const import (
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TOP_P,
@@ -77,10 +81,16 @@ async def test_form(hass: HomeAssistant) -> None:
assert result2["subentries"] == [
{
"subentry_type": "conversation",
"data": RECOMMENDED_OPTIONS,
"data": RECOMMENDED_CONVERSATION_OPTIONS,
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
},
{
"subentry_type": "ai_task_data",
"data": RECOMMENDED_AI_TASK_OPTIONS,
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
]
assert len(mock_setup_entry.mock_calls) == 1
@@ -131,14 +141,14 @@ async def test_creating_conversation_subentry(
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{"name": "My Custom Agent", **RECOMMENDED_OPTIONS},
{"name": "My Custom Agent", **RECOMMENDED_CONVERSATION_OPTIONS},
)
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "My Custom Agent"
processed_options = RECOMMENDED_OPTIONS.copy()
processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
assert result2["data"] == processed_options
@@ -709,3 +719,110 @@ async def test_subentry_web_search_user_location(
CONF_WEB_SEARCH_COUNTRY: "US",
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
}
async def test_creating_ai_task_subentry(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test creating an AI task subentry."""
old_subentries = set(mock_config_entry.subentries)
# Original conversation + original ai_task
assert len(mock_config_entry.subentries) == 2
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "ai_task_data"),
context={"source": config_entries.SOURCE_USER},
)
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "init"
assert not result.get("errors")
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{
"name": "Custom AI Task",
CONF_RECOMMENDED: True,
},
)
await hass.async_block_till_done()
assert result2.get("type") is FlowResultType.CREATE_ENTRY
assert result2.get("title") == "Custom AI Task"
assert result2.get("data") == {
CONF_RECOMMENDED: True,
}
assert (
len(mock_config_entry.subentries) == 3
) # Original conversation + original ai_task + new ai_task
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
new_subentry = mock_config_entry.subentries[new_subentry_id]
assert new_subentry.subentry_type == "ai_task_data"
assert new_subentry.title == "Custom AI Task"
async def test_ai_task_subentry_not_loaded(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating an AI task subentry when entry is not loaded."""
# Don't call mock_init_component to simulate not loaded state
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "ai_task_data"),
context={"source": config_entries.SOURCE_USER},
)
assert result.get("type") is FlowResultType.ABORT
assert result.get("reason") == "entry_not_loaded"
async def test_creating_ai_task_subentry_advanced(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test creating an AI task subentry with advanced settings."""
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "ai_task_data"),
context={"source": config_entries.SOURCE_USER},
)
assert result.get("type") is FlowResultType.FORM
assert result.get("step_id") == "init"
# Go to advanced settings
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{
"name": "Advanced AI Task",
CONF_RECOMMENDED: False,
},
)
assert result2.get("type") is FlowResultType.FORM
assert result2.get("step_id") == "advanced"
# Configure advanced settings
result3 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{
CONF_CHAT_MODEL: "gpt-4o",
CONF_MAX_TOKENS: 200,
CONF_TEMPERATURE: 0.5,
CONF_TOP_P: 0.9,
},
)
assert result3.get("type") is FlowResultType.CREATE_ENTRY
assert result3.get("title") == "Advanced AI Task"
assert result3.get("data") == {
CONF_RECOMMENDED: False,
CONF_CHAT_MODEL: "gpt-4o",
CONF_MAX_TOKENS: 200,
CONF_TEMPERATURE: 0.5,
CONF_TOP_P: 0.9,
}

View File

@@ -1,41 +1,15 @@
"""Tests for the OpenAI integration."""
from collections.abc import Generator
from unittest.mock import AsyncMock, patch
import httpx
from openai import AuthenticationError, RateLimitError
from openai.types import ResponseFormatText
from openai.types.responses import (
Response,
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseError,
ResponseErrorEvent,
ResponseFailedEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
ResponseIncompleteEvent,
ResponseInProgressEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
ResponseStreamEvent,
ResponseTextConfig,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
)
from openai.types.responses.response import IncompleteDetails
from openai.types.responses.response_function_web_search import ActionSearch
import pytest
from syrupy.assertion import SnapshotAssertion
@@ -55,6 +29,13 @@ from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
from . import (
create_function_tool_call_item,
create_message_item,
create_reasoning_item,
create_web_search_item,
)
from tests.common import MockConfigEntry
from tests.components.conversation import (
MockChatLog,
@@ -62,97 +43,6 @@ from tests.components.conversation import (
)
@pytest.fixture
def mock_create_stream() -> Generator[AsyncMock]:
"""Mock stream response."""
async def mock_generator(events, **kwargs):
response = Response(
id="resp_A",
created_at=1700000000,
error=None,
incomplete_details=None,
instructions=kwargs.get("instructions"),
metadata=kwargs.get("metadata", {}),
model=kwargs.get("model", "gpt-4o-mini"),
object="response",
output=[],
parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
temperature=kwargs.get("temperature", 1.0),
tool_choice=kwargs.get("tool_choice", "auto"),
tools=kwargs.get("tools"),
top_p=kwargs.get("top_p", 1.0),
max_output_tokens=kwargs.get("max_output_tokens", 100000),
previous_response_id=kwargs.get("previous_response_id"),
reasoning=kwargs.get("reasoning"),
status="in_progress",
text=kwargs.get(
"text", ResponseTextConfig(format=ResponseFormatText(type="text"))
),
truncation=kwargs.get("truncation", "disabled"),
usage=None,
user=kwargs.get("user"),
store=kwargs.get("store", True),
)
yield ResponseCreatedEvent(
response=response,
sequence_number=0,
type="response.created",
)
yield ResponseInProgressEvent(
response=response,
sequence_number=0,
type="response.in_progress",
)
response.status = "completed"
for value in events:
if isinstance(value, ResponseOutputItemDoneEvent):
response.output.append(value.item)
elif isinstance(value, IncompleteDetails):
response.status = "incomplete"
response.incomplete_details = value
break
if isinstance(value, ResponseError):
response.status = "failed"
response.error = value
break
yield value
if isinstance(value, ResponseErrorEvent):
return
if response.status == "incomplete":
yield ResponseIncompleteEvent(
response=response,
sequence_number=0,
type="response.incomplete",
)
elif response.status == "failed":
yield ResponseFailedEvent(
response=response,
sequence_number=0,
type="response.failed",
)
else:
yield ResponseCompletedEvent(
response=response,
sequence_number=0,
type="response.completed",
)
with patch(
"openai.resources.responses.AsyncResponses.create",
AsyncMock(),
) as mock_create:
mock_create.side_effect = lambda **kwargs: mock_generator(
mock_create.return_value.pop(0), **kwargs
)
yield mock_create
async def test_entity(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
@@ -347,225 +237,6 @@ async def test_conversation_agent(
assert agent.supported_languages == "*"
def create_message_item(
id: str, text: str | list[str], output_index: int
) -> list[ResponseStreamEvent]:
"""Create a message item."""
if isinstance(text, str):
text = [text]
content = ResponseOutputText(annotations=[], text="", type="output_text")
events = [
ResponseOutputItemAddedEvent(
item=ResponseOutputMessage(
id=id,
content=[],
type="message",
role="assistant",
status="in_progress",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseContentPartAddedEvent(
content_index=0,
item_id=id,
output_index=output_index,
part=content,
sequence_number=0,
type="response.content_part.added",
),
]
content.text = "".join(text)
events.extend(
ResponseTextDeltaEvent(
content_index=0,
delta=delta,
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.output_text.delta",
)
for delta in text
)
events.extend(
[
ResponseTextDoneEvent(
content_index=0,
item_id=id,
output_index=output_index,
text="".join(text),
sequence_number=0,
type="response.output_text.done",
),
ResponseContentPartDoneEvent(
content_index=0,
item_id=id,
output_index=output_index,
part=content,
sequence_number=0,
type="response.content_part.done",
),
ResponseOutputItemDoneEvent(
item=ResponseOutputMessage(
id=id,
content=[content],
role="assistant",
status="completed",
type="message",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]
)
return events
def create_function_tool_call_item(
id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
) -> list[ResponseStreamEvent]:
"""Create a function tool call item."""
if isinstance(arguments, str):
arguments = [arguments]
events = [
ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=id,
arguments="",
call_id=call_id,
name=name,
type="function_call",
status="in_progress",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
)
]
events.extend(
ResponseFunctionCallArgumentsDeltaEvent(
delta=delta,
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.function_call_arguments.delta",
)
for delta in arguments
)
events.append(
ResponseFunctionCallArgumentsDoneEvent(
arguments="".join(arguments),
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.function_call_arguments.done",
)
)
events.append(
ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=id,
arguments="".join(arguments),
call_id=call_id,
name=name,
type="function_call",
status="completed",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
)
)
return events
def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
"""Create a reasoning item."""
return [
ResponseOutputItemAddedEvent(
item=ResponseReasoningItem(
id=id,
summary=[],
type="reasoning",
status=None,
encrypted_content="AAA",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseOutputItemDoneEvent(
item=ResponseReasoningItem(
id=id,
summary=[],
type="reasoning",
status=None,
encrypted_content="AAABBB",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]
def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
"""Create a web search call item."""
return [
ResponseOutputItemAddedEvent(
item=ResponseFunctionWebSearch(
id=id,
status="in_progress",
action=ActionSearch(query="query", type="search"),
type="web_search_call",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.added",
),
ResponseWebSearchCallInProgressEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.in_progress",
),
ResponseWebSearchCallSearchingEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.searching",
),
ResponseWebSearchCallCompletedEvent(
item_id=id,
output_index=output_index,
sequence_number=0,
type="response.web_search_call.completed",
),
ResponseOutputItemDoneEvent(
item=ResponseFunctionWebSearch(
id=id,
status="completed",
action=ActionSearch(query="query", type="search"),
type="web_search_call",
),
output_index=output_index,
sequence_number=0,
type="response.output_item.done",
),
]
async def test_function_call(
hass: HomeAssistant,
mock_config_entry_with_reasoning_model: MockConfigEntry,

View File

@@ -0,0 +1,77 @@
"""Tests for the OpenAI Conversation entity."""
import voluptuous as vol
from homeassistant.components.openai_conversation.entity import (
_format_structured_output,
)
from homeassistant.helpers import selector
async def test_format_structured_output() -> None:
"""Test the format_structured_output function."""
schema = vol.Schema(
{
vol.Required("name"): selector.TextSelector(),
vol.Optional("age"): selector.NumberSelector(
config=selector.NumberSelectorConfig(
min=0,
max=120,
),
),
vol.Required("stuff"): selector.ObjectSelector(
{
"multiple": True,
"fields": {
"item_name": {
"selector": {"text": None},
},
"item_value": {
"selector": {"text": None},
},
},
}
),
}
)
assert _format_structured_output(schema, None) == {
"additionalProperties": False,
"properties": {
"age": {
"maximum": 120.0,
"minimum": 0.0,
"type": [
"number",
"null",
],
},
"name": {
"type": "string",
},
"stuff": {
"items": {
"properties": {
"item_name": {
"type": ["string", "null"],
},
"item_value": {
"type": ["string", "null"],
},
},
"required": [
"item_name",
"item_value",
],
"type": "object",
},
"type": "array",
},
},
"required": [
"name",
"stuff",
"age",
],
"strict": True,
"type": "object",
}

View File

@@ -17,7 +17,10 @@ from syrupy.assertion import SnapshotAssertion
from syrupy.filters import props
from homeassistant.components.openai_conversation import CONF_CHAT_MODEL
from homeassistant.components.openai_conversation.const import DOMAIN
from homeassistant.components.openai_conversation.const import (
DEFAULT_AI_TASK_NAME,
DOMAIN,
)
from homeassistant.config_entries import ConfigSubentryData
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
@@ -534,7 +537,7 @@
)
async def test_migration_from_v1_to_v2(
async def test_migration_from_v1(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
@@ -582,17 +585,33 @@
await hass.async_block_till_done()
assert mock_config_entry.version == 2
assert mock_config_entry.minor_version == 2
assert mock_config_entry.minor_version == 3
assert mock_config_entry.data == {"api_key": "1234"}
assert mock_config_entry.options == {}
assert len(mock_config_entry.subentries) == 1
assert len(mock_config_entry.subentries) == 2
subentry = next(iter(mock_config_entry.subentries.values()))
assert subentry.unique_id is None
assert subentry.title == "ChatGPT"
assert subentry.subentry_type == "conversation"
assert subentry.data == OPTIONS
# Find the conversation subentry
conversation_subentry = None
ai_task_subentry = None
for subentry in mock_config_entry.subentries.values():
if subentry.subentry_type == "conversation":
conversation_subentry = subentry
elif subentry.subentry_type == "ai_task_data":
ai_task_subentry = subentry
assert conversation_subentry is not None
assert conversation_subentry.unique_id is None
assert conversation_subentry.title == "ChatGPT"
assert conversation_subentry.subentry_type == "conversation"
assert conversation_subentry.data == OPTIONS
assert ai_task_subentry is not None
assert ai_task_subentry.unique_id is None
assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
assert ai_task_subentry.subentry_type == "ai_task_data"
# Use conversation subentry for the rest of the assertions
subentry = conversation_subentry
migrated_entity = entity_registry.async_get(entity.entity_id)
assert migrated_entity is not None
@@ -617,12 +636,12 @@
}
async def test_migration_from_v1_to_v2_with_multiple_keys(
async def test_migration_from_v1_with_multiple_keys(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 1 to version 2 with different API keys."""
"""Test migration from version 1 with different API keys."""
# Create two v1 config entries with different API keys
options = {
"recommended": True,
@@ -695,28 +714,38 @@
for idx, entry in enumerate(entries):
assert entry.version == 2
assert entry.minor_version == 2
assert entry.minor_version == 3
assert not entry.options
assert len(entry.subentries) == 1
subentry = list(entry.subentries.values())[0]
assert subentry.subentry_type == "conversation"
assert subentry.data == options
assert subentry.title == f"ChatGPT {idx + 1}"
assert len(entry.subentries) == 2
conversation_subentry = None
for subentry in entry.subentries.values():
if subentry.subentry_type == "conversation":
conversation_subentry = subentry
break
assert conversation_subentry is not None
assert conversation_subentry.subentry_type == "conversation"
assert conversation_subentry.data == options
assert conversation_subentry.title == f"ChatGPT {idx + 1}"
# Use conversation subentry for device assertions
subentry = conversation_subentry
dev = device_registry.async_get_device(
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
identifiers={(DOMAIN, subentry.subentry_id)}
)
assert dev is not None
assert dev.config_entries == {entry.entry_id}
assert dev.config_entries_subentries == {entry.entry_id: {subentry.subentry_id}}
async def test_migration_from_v1_to_v2_with_same_keys(
async def test_migration_from_v1_with_same_keys(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 1 to version 2 with same API keys consolidates entries."""
"""Test migration from version 1 with same API keys consolidates entries."""
# Create two v1 config entries with the same API key
options = {
"recommended": True,
@@ -790,17 +819,28 @@
entry = entries[0]
assert entry.version == 2
assert entry.minor_version == 2
assert entry.minor_version == 3
assert not entry.options
assert len(entry.subentries) == 2 # Two subentries from the two original entries
assert (
len(entry.subentries) == 3
) # Two conversation subentries + one AI task subentry
# Check both subentries exist with correct data
subentries = list(entry.subentries.values())
titles = [sub.title for sub in subentries]
# Check both conversation subentries exist with correct data
conversation_subentries = [
sub for sub in entry.subentries.values() if sub.subentry_type == "conversation"
]
ai_task_subentries = [
sub for sub in entry.subentries.values() if sub.subentry_type == "ai_task_data"
]
assert len(conversation_subentries) == 2
assert len(ai_task_subentries) == 1
titles = [sub.title for sub in conversation_subentries]
assert "ChatGPT" in titles
assert "ChatGPT 2" in titles
for subentry in subentries:
for subentry in conversation_subentries:
assert subentry.subentry_type == "conversation"
assert subentry.data == options
@@ -815,12 +855,12 @@
}
async def test_migration_from_v2_1_to_v2_2(
async def test_migration_from_v2_1(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 2.1 to version 2.2.
"""Test migration from version 2.1.
This tests we clean up the broken migration in Home Assistant Core
2025.7.0b0-2025.7.0b1:
@@ -913,16 +953,22 @@
assert len(entries) == 1
entry = entries[0]
assert entry.version == 2
assert entry.minor_version == 2
assert entry.minor_version == 3
assert not entry.options
assert entry.title == "ChatGPT"
assert len(entry.subentries) == 2
assert len(entry.subentries) == 3 # 2 conversation + 1 AI task
conversation_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "conversation"
]
ai_task_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "ai_task_data"
]
assert len(conversation_subentries) == 2
assert len(ai_task_subentries) == 1
for subentry in conversation_subentries:
assert subentry.subentry_type == "conversation"
assert subentry.data == options
@@ -972,7 +1018,9 @@
}
@pytest.mark.parametrize("mock_subentry_data", [{}, {CONF_CHAT_MODEL: "gpt-1o"}])
@pytest.mark.parametrize(
"mock_conversation_subentry_data", [{}, {CONF_CHAT_MODEL: "gpt-1o"}]
)
async def test_devices(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
@@ -980,12 +1028,89 @@
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
) -> None:
"""Assert exception when invalid config entry is provided."""
"""Test devices are correctly created for subentries."""
devices = dr.async_entries_for_config_entry(
device_registry, mock_config_entry.entry_id
)
assert len(devices) == 1
assert len(devices) == 2 # One for conversation, one for AI task
# Use the first device for snapshot comparison
device = devices[0]
assert device == snapshot(exclude=props("identifiers"))
subentry = next(iter(mock_config_entry.subentries.values()))
assert device.identifiers == {(DOMAIN, subentry.subentry_id)}
# Verify the device has identifiers matching one of the subentries
expected_identifiers = [
{(DOMAIN, subentry.subentry_id)}
for subentry in mock_config_entry.subentries.values()
]
assert device.identifiers in expected_identifiers
async def test_migration_from_v2_2(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 2.2."""
# Create a v2.2 config entry with a conversation subentry
options = {
"recommended": True,
"llm_hass_api": ["assist"],
"prompt": "You are a helpful assistant",
"chat_model": "gpt-4o-mini",
}
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={"api_key": "1234"},
entry_id="mock_entry_id",
version=2,
minor_version=2,
subentries_data=[
ConfigSubentryData(
data=options,
subentry_id="mock_id_1",
subentry_type="conversation",
title="ChatGPT",
unique_id=None,
),
],
title="ChatGPT",
)
mock_config_entry.add_to_hass(hass)
# Run migration
with patch(
"homeassistant.components.openai_conversation.async_setup_entry",
return_value=True,
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.version == 2
assert entry.minor_version == 3
assert not entry.options
assert entry.title == "ChatGPT"
assert len(entry.subentries) == 2
# Check conversation subentry is still there
conversation_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "conversation"
]
assert len(conversation_subentries) == 1
conversation_subentry = conversation_subentries[0]
assert conversation_subentry.data == options
# Check AI Task subentry was added
ai_task_subentries = [
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == "ai_task_data"
]
assert len(ai_task_subentries) == 1
ai_task_subentry = ai_task_subentries[0]
assert ai_task_subentry.data == {"recommended": True}
assert ai_task_subentry.title == "OpenAI AI Task"