mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 15:47:12 +00:00
Add OpenAI AI Task entity (#148295)
This commit is contained in:
parent
f0a636949a
commit
0e09a47476
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import MappingProxyType
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from openai.types.images_response import ImagesResponse
|
from openai.types.images_response import ImagesResponse
|
||||||
@ -45,9 +46,11 @@ from .const import (
|
|||||||
CONF_REASONING_EFFORT,
|
CONF_REASONING_EFFORT,
|
||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
|
DEFAULT_AI_TASK_NAME,
|
||||||
DEFAULT_NAME,
|
DEFAULT_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
LOGGER,
|
LOGGER,
|
||||||
|
RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
RECOMMENDED_MAX_TOKENS,
|
RECOMMENDED_MAX_TOKENS,
|
||||||
RECOMMENDED_REASONING_EFFORT,
|
RECOMMENDED_REASONING_EFFORT,
|
||||||
@ -59,7 +62,7 @@ from .entity import async_prepare_files_for_prompt
|
|||||||
SERVICE_GENERATE_IMAGE = "generate_image"
|
SERVICE_GENERATE_IMAGE = "generate_image"
|
||||||
SERVICE_GENERATE_CONTENT = "generate_content"
|
SERVICE_GENERATE_CONTENT = "generate_content"
|
||||||
|
|
||||||
PLATFORMS = (Platform.CONVERSATION,)
|
PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION)
|
||||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||||
|
|
||||||
type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
|
type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
|
||||||
@ -153,7 +156,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
EasyInputMessageParam(type="message", role="user", content=content)
|
EasyInputMessageParam(type="message", role="user", content=content)
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
|
||||||
model_args = {
|
model_args = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"input": messages,
|
"input": messages,
|
||||||
@ -175,6 +177,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
response: Response = await client.responses.create(**model_args)
|
response: Response = await client.responses.create(**model_args)
|
||||||
|
|
||||||
except openai.OpenAIError as err:
|
except openai.OpenAIError as err:
|
||||||
@ -361,6 +364,18 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) ->
|
|||||||
|
|
||||||
hass.config_entries.async_update_entry(entry, minor_version=2)
|
hass.config_entries.async_update_entry(entry, minor_version=2)
|
||||||
|
|
||||||
|
if entry.version == 2 and entry.minor_version == 2:
|
||||||
|
hass.config_entries.async_add_subentry(
|
||||||
|
entry,
|
||||||
|
ConfigSubentry(
|
||||||
|
data=MappingProxyType(RECOMMENDED_AI_TASK_OPTIONS),
|
||||||
|
subentry_type="ai_task_data",
|
||||||
|
title=DEFAULT_AI_TASK_NAME,
|
||||||
|
unique_id=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
hass.config_entries.async_update_entry(entry, minor_version=3)
|
||||||
|
|
||||||
LOGGER.debug(
|
LOGGER.debug(
|
||||||
"Migration to version %s:%s successful", entry.version, entry.minor_version
|
"Migration to version %s:%s successful", entry.version, entry.minor_version
|
||||||
)
|
)
|
||||||
|
77
homeassistant/components/openai_conversation/ai_task.py
Normal file
77
homeassistant/components/openai_conversation/ai_task.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
"""AI Task integration for OpenAI."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from json import JSONDecodeError
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from homeassistant.components import ai_task, conversation
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
from homeassistant.util.json import json_loads
|
||||||
|
|
||||||
|
from .entity import OpenAIBaseLLMEntity
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up AI Task entities."""
|
||||||
|
for subentry in config_entry.subentries.values():
|
||||||
|
if subentry.subentry_type != "ai_task_data":
|
||||||
|
continue
|
||||||
|
|
||||||
|
async_add_entities(
|
||||||
|
[OpenAITaskEntity(config_entry, subentry)],
|
||||||
|
config_subentry_id=subentry.subentry_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAITaskEntity(
|
||||||
|
ai_task.AITaskEntity,
|
||||||
|
OpenAIBaseLLMEntity,
|
||||||
|
):
|
||||||
|
"""OpenAI AI Task entity."""
|
||||||
|
|
||||||
|
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||||
|
|
||||||
|
async def _async_generate_data(
|
||||||
|
self,
|
||||||
|
task: ai_task.GenDataTask,
|
||||||
|
chat_log: conversation.ChatLog,
|
||||||
|
) -> ai_task.GenDataTaskResult:
|
||||||
|
"""Handle a generate data task."""
|
||||||
|
await self._async_handle_chat_log(chat_log, task.name, task.structure)
|
||||||
|
|
||||||
|
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||||
|
raise HomeAssistantError(
|
||||||
|
"Last content in chat log is not an AssistantContent"
|
||||||
|
)
|
||||||
|
|
||||||
|
text = chat_log.content[-1].content or ""
|
||||||
|
|
||||||
|
if not task.structure:
|
||||||
|
return ai_task.GenDataTaskResult(
|
||||||
|
conversation_id=chat_log.conversation_id,
|
||||||
|
data=text,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
data = json_loads(text)
|
||||||
|
except JSONDecodeError as err:
|
||||||
|
_LOGGER.error(
|
||||||
|
"Failed to parse JSON response: %s. Response: %s",
|
||||||
|
err,
|
||||||
|
text,
|
||||||
|
)
|
||||||
|
raise HomeAssistantError("Error with OpenAI structured response") from err
|
||||||
|
|
||||||
|
return ai_task.GenDataTaskResult(
|
||||||
|
conversation_id=chat_log.conversation_id,
|
||||||
|
data=data,
|
||||||
|
)
|
@ -55,9 +55,12 @@ from .const import (
|
|||||||
CONF_WEB_SEARCH_REGION,
|
CONF_WEB_SEARCH_REGION,
|
||||||
CONF_WEB_SEARCH_TIMEZONE,
|
CONF_WEB_SEARCH_TIMEZONE,
|
||||||
CONF_WEB_SEARCH_USER_LOCATION,
|
CONF_WEB_SEARCH_USER_LOCATION,
|
||||||
|
DEFAULT_AI_TASK_NAME,
|
||||||
DEFAULT_CONVERSATION_NAME,
|
DEFAULT_CONVERSATION_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||||
RECOMMENDED_MAX_TOKENS,
|
RECOMMENDED_MAX_TOKENS,
|
||||||
RECOMMENDED_REASONING_EFFORT,
|
RECOMMENDED_REASONING_EFFORT,
|
||||||
RECOMMENDED_TEMPERATURE,
|
RECOMMENDED_TEMPERATURE,
|
||||||
@ -77,12 +80,6 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
RECOMMENDED_OPTIONS = {
|
|
||||||
CONF_RECOMMENDED: True,
|
|
||||||
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
|
||||||
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
|
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
|
||||||
"""Validate the user input allows us to connect.
|
"""Validate the user input allows us to connect.
|
||||||
@ -99,7 +96,7 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
"""Handle a config flow for OpenAI Conversation."""
|
"""Handle a config flow for OpenAI Conversation."""
|
||||||
|
|
||||||
VERSION = 2
|
VERSION = 2
|
||||||
MINOR_VERSION = 2
|
MINOR_VERSION = 3
|
||||||
|
|
||||||
async def async_step_user(
|
async def async_step_user(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
@ -129,10 +126,16 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
subentries=[
|
subentries=[
|
||||||
{
|
{
|
||||||
"subentry_type": "conversation",
|
"subentry_type": "conversation",
|
||||||
"data": RECOMMENDED_OPTIONS,
|
"data": RECOMMENDED_CONVERSATION_OPTIONS,
|
||||||
"title": DEFAULT_CONVERSATION_NAME,
|
"title": DEFAULT_CONVERSATION_NAME,
|
||||||
"unique_id": None,
|
"unique_id": None,
|
||||||
}
|
},
|
||||||
|
{
|
||||||
|
"subentry_type": "ai_task_data",
|
||||||
|
"data": RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
|
"title": DEFAULT_AI_TASK_NAME,
|
||||||
|
"unique_id": None,
|
||||||
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -146,11 +149,14 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
cls, config_entry: ConfigEntry
|
cls, config_entry: ConfigEntry
|
||||||
) -> dict[str, type[ConfigSubentryFlow]]:
|
) -> dict[str, type[ConfigSubentryFlow]]:
|
||||||
"""Return subentries supported by this integration."""
|
"""Return subentries supported by this integration."""
|
||||||
return {"conversation": ConversationSubentryFlowHandler}
|
return {
|
||||||
|
"conversation": OpenAISubentryFlowHandler,
|
||||||
|
"ai_task_data": OpenAISubentryFlowHandler,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
class OpenAISubentryFlowHandler(ConfigSubentryFlow):
|
||||||
"""Flow for managing conversation subentries."""
|
"""Flow for managing OpenAI subentries."""
|
||||||
|
|
||||||
last_rendered_recommended = False
|
last_rendered_recommended = False
|
||||||
options: dict[str, Any]
|
options: dict[str, Any]
|
||||||
@ -164,7 +170,10 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> SubentryFlowResult:
|
) -> SubentryFlowResult:
|
||||||
"""Add a subentry."""
|
"""Add a subentry."""
|
||||||
self.options = RECOMMENDED_OPTIONS.copy()
|
if self._subentry_type == "ai_task_data":
|
||||||
|
self.options = RECOMMENDED_AI_TASK_OPTIONS.copy()
|
||||||
|
else:
|
||||||
|
self.options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
||||||
return await self.async_step_init()
|
return await self.async_step_init()
|
||||||
|
|
||||||
async def async_step_reconfigure(
|
async def async_step_reconfigure(
|
||||||
@ -181,6 +190,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
# abort if entry is not loaded
|
# abort if entry is not loaded
|
||||||
if self._get_entry().state != ConfigEntryState.LOADED:
|
if self._get_entry().state != ConfigEntryState.LOADED:
|
||||||
return self.async_abort(reason="entry_not_loaded")
|
return self.async_abort(reason="entry_not_loaded")
|
||||||
|
|
||||||
options = self.options
|
options = self.options
|
||||||
|
|
||||||
hass_apis: list[SelectOptionDict] = [
|
hass_apis: list[SelectOptionDict] = [
|
||||||
@ -198,10 +208,13 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
step_schema: VolDictType = {}
|
step_schema: VolDictType = {}
|
||||||
|
|
||||||
if self._is_new:
|
if self._is_new:
|
||||||
step_schema[vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME)] = (
|
if self._subentry_type == "ai_task_data":
|
||||||
str
|
default_name = DEFAULT_AI_TASK_NAME
|
||||||
)
|
else:
|
||||||
|
default_name = DEFAULT_CONVERSATION_NAME
|
||||||
|
step_schema[vol.Required(CONF_NAME, default=default_name)] = str
|
||||||
|
|
||||||
|
if self._subentry_type == "conversation":
|
||||||
step_schema.update(
|
step_schema.update(
|
||||||
{
|
{
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
@ -215,12 +228,13 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
vol.Optional(CONF_LLM_HASS_API): SelectSelector(
|
vol.Optional(CONF_LLM_HASS_API): SelectSelector(
|
||||||
SelectSelectorConfig(options=hass_apis, multiple=True)
|
SelectSelectorConfig(options=hass_apis, multiple=True)
|
||||||
),
|
),
|
||||||
vol.Required(
|
|
||||||
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
|
|
||||||
): bool,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
step_schema[
|
||||||
|
vol.Required(CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False))
|
||||||
|
] = bool
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
if not user_input.get(CONF_LLM_HASS_API):
|
if not user_input.get(CONF_LLM_HASS_API):
|
||||||
user_input.pop(CONF_LLM_HASS_API, None)
|
user_input.pop(CONF_LLM_HASS_API, None)
|
||||||
@ -320,7 +334,9 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
elif CONF_REASONING_EFFORT in options:
|
elif CONF_REASONING_EFFORT in options:
|
||||||
options.pop(CONF_REASONING_EFFORT)
|
options.pop(CONF_REASONING_EFFORT)
|
||||||
|
|
||||||
if not model.startswith(tuple(UNSUPPORTED_WEB_SEARCH_MODELS)):
|
if self._subentry_type == "conversation" and not model.startswith(
|
||||||
|
tuple(UNSUPPORTED_WEB_SEARCH_MODELS)
|
||||||
|
):
|
||||||
step_schema.update(
|
step_schema.update(
|
||||||
{
|
{
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
@ -362,7 +378,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
if not step_schema:
|
if not step_schema:
|
||||||
if self._is_new:
|
if self._is_new:
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=options.pop(CONF_NAME, DEFAULT_CONVERSATION_NAME),
|
title=options.pop(CONF_NAME),
|
||||||
data=options,
|
data=options,
|
||||||
)
|
)
|
||||||
return self.async_update_and_abort(
|
return self.async_update_and_abort(
|
||||||
@ -384,7 +400,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
options.update(user_input)
|
options.update(user_input)
|
||||||
if self._is_new:
|
if self._is_new:
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=options.pop(CONF_NAME, DEFAULT_CONVERSATION_NAME),
|
title=options.pop(CONF_NAME),
|
||||||
data=options,
|
data=options,
|
||||||
)
|
)
|
||||||
return self.async_update_and_abort(
|
return self.async_update_and_abort(
|
||||||
|
@ -2,10 +2,14 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
|
from homeassistant.helpers import llm
|
||||||
|
|
||||||
DOMAIN = "openai_conversation"
|
DOMAIN = "openai_conversation"
|
||||||
LOGGER: logging.Logger = logging.getLogger(__package__)
|
LOGGER: logging.Logger = logging.getLogger(__package__)
|
||||||
|
|
||||||
DEFAULT_CONVERSATION_NAME = "OpenAI Conversation"
|
DEFAULT_CONVERSATION_NAME = "OpenAI Conversation"
|
||||||
|
DEFAULT_AI_TASK_NAME = "OpenAI AI Task"
|
||||||
DEFAULT_NAME = "OpenAI Conversation"
|
DEFAULT_NAME = "OpenAI Conversation"
|
||||||
|
|
||||||
CONF_CHAT_MODEL = "chat_model"
|
CONF_CHAT_MODEL = "chat_model"
|
||||||
@ -51,3 +55,12 @@ UNSUPPORTED_WEB_SEARCH_MODELS: list[str] = [
|
|||||||
"o1",
|
"o1",
|
||||||
"o3-mini",
|
"o3-mini",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
||||||
|
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||||
|
}
|
||||||
|
RECOMMENDED_AI_TASK_OPTIONS = {
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
}
|
||||||
|
@ -39,6 +39,7 @@ from openai.types.responses import (
|
|||||||
)
|
)
|
||||||
from openai.types.responses.response_input_param import FunctionCallOutput
|
from openai.types.responses.response_input_param import FunctionCallOutput
|
||||||
from openai.types.responses.web_search_tool_param import UserLocation
|
from openai.types.responses.web_search_tool_param import UserLocation
|
||||||
|
import voluptuous as vol
|
||||||
from voluptuous_openapi import convert
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
@ -47,6 +48,7 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import device_registry as dr, llm
|
from homeassistant.helpers import device_registry as dr, llm
|
||||||
from homeassistant.helpers.entity import Entity
|
from homeassistant.helpers.entity import Entity
|
||||||
|
from homeassistant.util import slugify
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
@ -79,6 +81,47 @@ if TYPE_CHECKING:
|
|||||||
MAX_TOOL_ITERATIONS = 10
|
MAX_TOOL_ITERATIONS = 10
|
||||||
|
|
||||||
|
|
||||||
|
def _adjust_schema(schema: dict[str, Any]) -> None:
|
||||||
|
"""Adjust the schema to be compatible with OpenAI API."""
|
||||||
|
if schema["type"] == "object":
|
||||||
|
if "properties" not in schema:
|
||||||
|
return
|
||||||
|
|
||||||
|
if "required" not in schema:
|
||||||
|
schema["required"] = []
|
||||||
|
|
||||||
|
# Ensure all properties are required
|
||||||
|
for prop, prop_info in schema["properties"].items():
|
||||||
|
_adjust_schema(prop_info)
|
||||||
|
if prop not in schema["required"]:
|
||||||
|
prop_info["type"] = [prop_info["type"], "null"]
|
||||||
|
schema["required"].append(prop)
|
||||||
|
|
||||||
|
elif schema["type"] == "array":
|
||||||
|
if "items" not in schema:
|
||||||
|
return
|
||||||
|
|
||||||
|
_adjust_schema(schema["items"])
|
||||||
|
|
||||||
|
|
||||||
|
def _format_structured_output(
|
||||||
|
schema: vol.Schema, llm_api: llm.APIInstance | None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Format the schema to be compatible with OpenAI API."""
|
||||||
|
result: dict[str, Any] = convert(
|
||||||
|
schema,
|
||||||
|
custom_serializer=(
|
||||||
|
llm_api.custom_serializer if llm_api else llm.selector_serializer
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
_adjust_schema(result)
|
||||||
|
|
||||||
|
result["strict"] = True
|
||||||
|
result["additionalProperties"] = False
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _format_tool(
|
def _format_tool(
|
||||||
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||||
) -> FunctionToolParam:
|
) -> FunctionToolParam:
|
||||||
@ -243,6 +286,8 @@ class OpenAIBaseLLMEntity(Entity):
|
|||||||
async def _async_handle_chat_log(
|
async def _async_handle_chat_log(
|
||||||
self,
|
self,
|
||||||
chat_log: conversation.ChatLog,
|
chat_log: conversation.ChatLog,
|
||||||
|
structure_name: str | None = None,
|
||||||
|
structure: vol.Schema | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Generate an answer for the chat log."""
|
"""Generate an answer for the chat log."""
|
||||||
options = self.subentry.data
|
options = self.subentry.data
|
||||||
@ -273,23 +318,10 @@ class OpenAIBaseLLMEntity(Entity):
|
|||||||
tools = []
|
tools = []
|
||||||
tools.append(web_search)
|
tools.append(web_search)
|
||||||
|
|
||||||
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
|
||||||
messages = [
|
|
||||||
m
|
|
||||||
for content in chat_log.content
|
|
||||||
for m in _convert_content_to_param(content)
|
|
||||||
]
|
|
||||||
|
|
||||||
client = self.entry.runtime_data
|
|
||||||
|
|
||||||
# To prevent infinite loops, we limit the number of iterations
|
|
||||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
|
||||||
model_args = {
|
model_args = {
|
||||||
"model": model,
|
"model": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||||
"input": messages,
|
"input": [],
|
||||||
"max_output_tokens": options.get(
|
"max_output_tokens": options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||||
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
|
|
||||||
),
|
|
||||||
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||||
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||||
"user": chat_log.conversation_id,
|
"user": chat_log.conversation_id,
|
||||||
@ -299,13 +331,34 @@ class OpenAIBaseLLMEntity(Entity):
|
|||||||
if tools:
|
if tools:
|
||||||
model_args["tools"] = tools
|
model_args["tools"] = tools
|
||||||
|
|
||||||
if model.startswith("o"):
|
if model_args["model"].startswith("o"):
|
||||||
model_args["reasoning"] = {
|
model_args["reasoning"] = {
|
||||||
"effort": options.get(
|
"effort": options.get(
|
||||||
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
|
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
model_args["include"] = ["reasoning.encrypted_content"]
|
else:
|
||||||
|
model_args["store"] = False
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
m
|
||||||
|
for content in chat_log.content
|
||||||
|
for m in _convert_content_to_param(content)
|
||||||
|
]
|
||||||
|
if structure and structure_name:
|
||||||
|
model_args["text"] = {
|
||||||
|
"format": {
|
||||||
|
"type": "json_schema",
|
||||||
|
"name": slugify(structure_name),
|
||||||
|
"schema": _format_structured_output(structure, chat_log.llm_api),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client = self.entry.runtime_data
|
||||||
|
|
||||||
|
# To prevent infinite loops, we limit the number of iterations
|
||||||
|
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||||
|
model_args["input"] = messages
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await client.responses.create(**model_args)
|
result = await client.responses.create(**model_args)
|
||||||
|
@ -68,6 +68,52 @@
|
|||||||
"error": {
|
"error": {
|
||||||
"model_not_supported": "This model is not supported, please select a different model"
|
"model_not_supported": "This model is not supported, please select a different model"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"ai_task_data": {
|
||||||
|
"initiate_flow": {
|
||||||
|
"user": "Add Generate data with AI service",
|
||||||
|
"reconfigure": "Reconfigure Generate data with AI service"
|
||||||
|
},
|
||||||
|
"entry_type": "Generate data with AI service",
|
||||||
|
"step": {
|
||||||
|
"init": {
|
||||||
|
"data": {
|
||||||
|
"name": "[%key:common::config_flow::data::name%]",
|
||||||
|
"recommended": "[%key:component::openai_conversation::config_subentries::conversation::step::init::data::recommended%]"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"advanced": {
|
||||||
|
"title": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::title%]",
|
||||||
|
"data": {
|
||||||
|
"chat_model": "[%key:common::generic::model%]",
|
||||||
|
"max_tokens": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::max_tokens%]",
|
||||||
|
"temperature": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::temperature%]",
|
||||||
|
"top_p": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::top_p%]"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"title": "[%key:component::openai_conversation::config_subentries::conversation::step::model::title%]",
|
||||||
|
"data": {
|
||||||
|
"reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::reasoning_effort%]",
|
||||||
|
"web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::web_search%]",
|
||||||
|
"search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::search_context_size%]",
|
||||||
|
"user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::user_location%]"
|
||||||
|
},
|
||||||
|
"data_description": {
|
||||||
|
"reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::reasoning_effort%]",
|
||||||
|
"web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::web_search%]",
|
||||||
|
"search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::search_context_size%]",
|
||||||
|
"user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::user_location%]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"abort": {
|
||||||
|
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
|
||||||
|
"entry_not_loaded": "[%key:component::openai_conversation::config_subentries::conversation::abort::entry_not_loaded%]"
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"model_not_supported": "[%key:component::openai_conversation::config_subentries::conversation::error::model_not_supported%]"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"selector": {
|
"selector": {
|
||||||
|
@ -1 +1,241 @@
|
|||||||
"""Tests for the OpenAI Conversation integration."""
|
"""Tests for the OpenAI Conversation integration."""
|
||||||
|
|
||||||
|
from openai.types.responses import (
|
||||||
|
ResponseContentPartAddedEvent,
|
||||||
|
ResponseContentPartDoneEvent,
|
||||||
|
ResponseFunctionCallArgumentsDeltaEvent,
|
||||||
|
ResponseFunctionCallArgumentsDoneEvent,
|
||||||
|
ResponseFunctionToolCall,
|
||||||
|
ResponseFunctionWebSearch,
|
||||||
|
ResponseOutputItemAddedEvent,
|
||||||
|
ResponseOutputItemDoneEvent,
|
||||||
|
ResponseOutputMessage,
|
||||||
|
ResponseOutputText,
|
||||||
|
ResponseReasoningItem,
|
||||||
|
ResponseStreamEvent,
|
||||||
|
ResponseTextDeltaEvent,
|
||||||
|
ResponseTextDoneEvent,
|
||||||
|
ResponseWebSearchCallCompletedEvent,
|
||||||
|
ResponseWebSearchCallInProgressEvent,
|
||||||
|
ResponseWebSearchCallSearchingEvent,
|
||||||
|
)
|
||||||
|
from openai.types.responses.response_function_web_search import ActionSearch
|
||||||
|
|
||||||
|
|
||||||
|
def create_message_item(
|
||||||
|
id: str, text: str | list[str], output_index: int
|
||||||
|
) -> list[ResponseStreamEvent]:
|
||||||
|
"""Create a message item."""
|
||||||
|
if isinstance(text, str):
|
||||||
|
text = [text]
|
||||||
|
|
||||||
|
content = ResponseOutputText(annotations=[], text="", type="output_text")
|
||||||
|
events = [
|
||||||
|
ResponseOutputItemAddedEvent(
|
||||||
|
item=ResponseOutputMessage(
|
||||||
|
id=id,
|
||||||
|
content=[],
|
||||||
|
type="message",
|
||||||
|
role="assistant",
|
||||||
|
status="in_progress",
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.output_item.added",
|
||||||
|
),
|
||||||
|
ResponseContentPartAddedEvent(
|
||||||
|
content_index=0,
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
part=content,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.content_part.added",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
content.text = "".join(text)
|
||||||
|
events.extend(
|
||||||
|
ResponseTextDeltaEvent(
|
||||||
|
content_index=0,
|
||||||
|
delta=delta,
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.output_text.delta",
|
||||||
|
)
|
||||||
|
for delta in text
|
||||||
|
)
|
||||||
|
|
||||||
|
events.extend(
|
||||||
|
[
|
||||||
|
ResponseTextDoneEvent(
|
||||||
|
content_index=0,
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
text="".join(text),
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.output_text.done",
|
||||||
|
),
|
||||||
|
ResponseContentPartDoneEvent(
|
||||||
|
content_index=0,
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
part=content,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.content_part.done",
|
||||||
|
),
|
||||||
|
ResponseOutputItemDoneEvent(
|
||||||
|
item=ResponseOutputMessage(
|
||||||
|
id=id,
|
||||||
|
content=[content],
|
||||||
|
role="assistant",
|
||||||
|
status="completed",
|
||||||
|
type="message",
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.output_item.done",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return events
|
||||||
|
|
||||||
|
|
||||||
|
def create_function_tool_call_item(
|
||||||
|
id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
|
||||||
|
) -> list[ResponseStreamEvent]:
|
||||||
|
"""Create a function tool call item."""
|
||||||
|
if isinstance(arguments, str):
|
||||||
|
arguments = [arguments]
|
||||||
|
|
||||||
|
events = [
|
||||||
|
ResponseOutputItemAddedEvent(
|
||||||
|
item=ResponseFunctionToolCall(
|
||||||
|
id=id,
|
||||||
|
arguments="",
|
||||||
|
call_id=call_id,
|
||||||
|
name=name,
|
||||||
|
type="function_call",
|
||||||
|
status="in_progress",
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.output_item.added",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
events.extend(
|
||||||
|
ResponseFunctionCallArgumentsDeltaEvent(
|
||||||
|
delta=delta,
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.function_call_arguments.delta",
|
||||||
|
)
|
||||||
|
for delta in arguments
|
||||||
|
)
|
||||||
|
|
||||||
|
events.append(
|
||||||
|
ResponseFunctionCallArgumentsDoneEvent(
|
||||||
|
arguments="".join(arguments),
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.function_call_arguments.done",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
events.append(
|
||||||
|
ResponseOutputItemDoneEvent(
|
||||||
|
item=ResponseFunctionToolCall(
|
||||||
|
id=id,
|
||||||
|
arguments="".join(arguments),
|
||||||
|
call_id=call_id,
|
||||||
|
name=name,
|
||||||
|
type="function_call",
|
||||||
|
status="completed",
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.output_item.done",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return events
|
||||||
|
|
||||||
|
|
||||||
|
def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
|
||||||
|
"""Create a reasoning item."""
|
||||||
|
return [
|
||||||
|
ResponseOutputItemAddedEvent(
|
||||||
|
item=ResponseReasoningItem(
|
||||||
|
id=id,
|
||||||
|
summary=[],
|
||||||
|
type="reasoning",
|
||||||
|
status=None,
|
||||||
|
encrypted_content="AAA",
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.output_item.added",
|
||||||
|
),
|
||||||
|
ResponseOutputItemDoneEvent(
|
||||||
|
item=ResponseReasoningItem(
|
||||||
|
id=id,
|
||||||
|
summary=[],
|
||||||
|
type="reasoning",
|
||||||
|
status=None,
|
||||||
|
encrypted_content="AAABBB",
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.output_item.done",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
|
||||||
|
"""Create a web search call item."""
|
||||||
|
return [
|
||||||
|
ResponseOutputItemAddedEvent(
|
||||||
|
item=ResponseFunctionWebSearch(
|
||||||
|
id=id,
|
||||||
|
status="in_progress",
|
||||||
|
action=ActionSearch(query="query", type="search"),
|
||||||
|
type="web_search_call",
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.output_item.added",
|
||||||
|
),
|
||||||
|
ResponseWebSearchCallInProgressEvent(
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.web_search_call.in_progress",
|
||||||
|
),
|
||||||
|
ResponseWebSearchCallSearchingEvent(
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.web_search_call.searching",
|
||||||
|
),
|
||||||
|
ResponseWebSearchCallCompletedEvent(
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.web_search_call.completed",
|
||||||
|
),
|
||||||
|
ResponseOutputItemDoneEvent(
|
||||||
|
item=ResponseFunctionWebSearch(
|
||||||
|
id=id,
|
||||||
|
status="completed",
|
||||||
|
action=ActionSearch(query="query", type="search"),
|
||||||
|
type="web_search_call",
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.output_item.done",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
@ -1,13 +1,30 @@
|
|||||||
"""Tests helpers."""
|
"""Tests helpers."""
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
from openai.types import ResponseFormatText
|
||||||
|
from openai.types.responses import (
|
||||||
|
Response,
|
||||||
|
ResponseCompletedEvent,
|
||||||
|
ResponseCreatedEvent,
|
||||||
|
ResponseError,
|
||||||
|
ResponseErrorEvent,
|
||||||
|
ResponseFailedEvent,
|
||||||
|
ResponseIncompleteEvent,
|
||||||
|
ResponseInProgressEvent,
|
||||||
|
ResponseOutputItemDoneEvent,
|
||||||
|
ResponseTextConfig,
|
||||||
|
)
|
||||||
|
from openai.types.responses.response import IncompleteDetails
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.openai_conversation.const import (
|
from homeassistant.components.openai_conversation.const import (
|
||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
|
DEFAULT_AI_TASK_NAME,
|
||||||
DEFAULT_CONVERSATION_NAME,
|
DEFAULT_CONVERSATION_NAME,
|
||||||
|
RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
)
|
)
|
||||||
from homeassistant.config_entries import ConfigSubentryData
|
from homeassistant.config_entries import ConfigSubentryData
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
@ -19,14 +36,14 @@ from tests.common import MockConfigEntry
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_subentry_data() -> dict[str, Any]:
|
def mock_conversation_subentry_data() -> dict[str, Any]:
|
||||||
"""Mock subentry data."""
|
"""Mock subentry data."""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config_entry(
|
def mock_config_entry(
|
||||||
hass: HomeAssistant, mock_subentry_data: dict[str, Any]
|
hass: HomeAssistant, mock_conversation_subentry_data: dict[str, Any]
|
||||||
) -> MockConfigEntry:
|
) -> MockConfigEntry:
|
||||||
"""Mock a config entry."""
|
"""Mock a config entry."""
|
||||||
entry = MockConfigEntry(
|
entry = MockConfigEntry(
|
||||||
@ -36,13 +53,20 @@ def mock_config_entry(
|
|||||||
"api_key": "bla",
|
"api_key": "bla",
|
||||||
},
|
},
|
||||||
version=2,
|
version=2,
|
||||||
|
minor_version=3,
|
||||||
subentries_data=[
|
subentries_data=[
|
||||||
ConfigSubentryData(
|
ConfigSubentryData(
|
||||||
data=mock_subentry_data,
|
data=mock_conversation_subentry_data,
|
||||||
subentry_type="conversation",
|
subentry_type="conversation",
|
||||||
title=DEFAULT_CONVERSATION_NAME,
|
title=DEFAULT_CONVERSATION_NAME,
|
||||||
unique_id=None,
|
unique_id=None,
|
||||||
)
|
),
|
||||||
|
ConfigSubentryData(
|
||||||
|
data=RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
|
subentry_type="ai_task_data",
|
||||||
|
title=DEFAULT_AI_TASK_NAME,
|
||||||
|
unique_id=None,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
entry.add_to_hass(hass)
|
entry.add_to_hass(hass)
|
||||||
@ -91,3 +115,94 @@ async def mock_init_component(
|
|||||||
async def setup_ha(hass: HomeAssistant) -> None:
|
async def setup_ha(hass: HomeAssistant) -> None:
|
||||||
"""Set up Home Assistant."""
|
"""Set up Home Assistant."""
|
||||||
assert await async_setup_component(hass, "homeassistant", {})
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_create_stream() -> Generator[AsyncMock]:
|
||||||
|
"""Mock stream response."""
|
||||||
|
|
||||||
|
async def mock_generator(events, **kwargs):
|
||||||
|
response = Response(
|
||||||
|
id="resp_A",
|
||||||
|
created_at=1700000000,
|
||||||
|
error=None,
|
||||||
|
incomplete_details=None,
|
||||||
|
instructions=kwargs.get("instructions"),
|
||||||
|
metadata=kwargs.get("metadata", {}),
|
||||||
|
model=kwargs.get("model", "gpt-4o-mini"),
|
||||||
|
object="response",
|
||||||
|
output=[],
|
||||||
|
parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
|
||||||
|
temperature=kwargs.get("temperature", 1.0),
|
||||||
|
tool_choice=kwargs.get("tool_choice", "auto"),
|
||||||
|
tools=kwargs.get("tools", []),
|
||||||
|
top_p=kwargs.get("top_p", 1.0),
|
||||||
|
max_output_tokens=kwargs.get("max_output_tokens", 100000),
|
||||||
|
previous_response_id=kwargs.get("previous_response_id"),
|
||||||
|
reasoning=kwargs.get("reasoning"),
|
||||||
|
status="in_progress",
|
||||||
|
text=kwargs.get(
|
||||||
|
"text", ResponseTextConfig(format=ResponseFormatText(type="text"))
|
||||||
|
),
|
||||||
|
truncation=kwargs.get("truncation", "disabled"),
|
||||||
|
usage=None,
|
||||||
|
user=kwargs.get("user"),
|
||||||
|
store=kwargs.get("store", True),
|
||||||
|
)
|
||||||
|
yield ResponseCreatedEvent(
|
||||||
|
response=response,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.created",
|
||||||
|
)
|
||||||
|
yield ResponseInProgressEvent(
|
||||||
|
response=response,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.in_progress",
|
||||||
|
)
|
||||||
|
response.status = "completed"
|
||||||
|
|
||||||
|
for value in events:
|
||||||
|
if isinstance(value, ResponseOutputItemDoneEvent):
|
||||||
|
response.output.append(value.item)
|
||||||
|
elif isinstance(value, IncompleteDetails):
|
||||||
|
response.status = "incomplete"
|
||||||
|
response.incomplete_details = value
|
||||||
|
break
|
||||||
|
if isinstance(value, ResponseError):
|
||||||
|
response.status = "failed"
|
||||||
|
response.error = value
|
||||||
|
break
|
||||||
|
|
||||||
|
yield value
|
||||||
|
|
||||||
|
if isinstance(value, ResponseErrorEvent):
|
||||||
|
return
|
||||||
|
|
||||||
|
if response.status == "incomplete":
|
||||||
|
yield ResponseIncompleteEvent(
|
||||||
|
response=response,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.incomplete",
|
||||||
|
)
|
||||||
|
elif response.status == "failed":
|
||||||
|
yield ResponseFailedEvent(
|
||||||
|
response=response,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.failed",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield ResponseCompletedEvent(
|
||||||
|
response=response,
|
||||||
|
sequence_number=0,
|
||||||
|
type="response.completed",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"openai.resources.responses.AsyncResponses.create",
|
||||||
|
AsyncMock(),
|
||||||
|
) as mock_create:
|
||||||
|
mock_create.side_effect = lambda **kwargs: mock_generator(
|
||||||
|
mock_create.return_value.pop(0), **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
yield mock_create
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_devices[mock_subentry_data0]
|
# name: test_devices[mock_conversation_subentry_data0]
|
||||||
DeviceRegistryEntrySnapshot({
|
DeviceRegistryEntrySnapshot({
|
||||||
'area_id': None,
|
'area_id': None,
|
||||||
'config_entries': <ANY>,
|
'config_entries': <ANY>,
|
||||||
@ -26,7 +26,7 @@
|
|||||||
'via_device_id': None,
|
'via_device_id': None,
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_devices[mock_subentry_data1]
|
# name: test_devices[mock_conversation_subentry_data1]
|
||||||
DeviceRegistryEntrySnapshot({
|
DeviceRegistryEntrySnapshot({
|
||||||
'area_id': None,
|
'area_id': None,
|
||||||
'config_entries': <ANY>,
|
'config_entries': <ANY>,
|
||||||
|
124
tests/components/openai_conversation/test_ai_task.py
Normal file
124
tests/components/openai_conversation/test_ai_task.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
"""Test AI Task platform of OpenAI Conversation integration."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components import ai_task
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers import entity_registry as er, selector
|
||||||
|
|
||||||
|
from . import create_message_item
|
||||||
|
|
||||||
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
|
async def test_generate_data(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_create_stream: AsyncMock,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test AI Task data generation."""
|
||||||
|
entity_id = "ai_task.openai_ai_task"
|
||||||
|
|
||||||
|
# Ensure entity is linked to the subentry
|
||||||
|
entity_entry = entity_registry.async_get(entity_id)
|
||||||
|
ai_task_entry = next(
|
||||||
|
iter(
|
||||||
|
entry
|
||||||
|
for entry in mock_config_entry.subentries.values()
|
||||||
|
if entry.subentry_type == "ai_task_data"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert entity_entry is not None
|
||||||
|
assert entity_entry.config_entry_id == mock_config_entry.entry_id
|
||||||
|
assert entity_entry.config_subentry_id == ai_task_entry.subentry_id
|
||||||
|
|
||||||
|
# Mock the OpenAI response stream
|
||||||
|
mock_create_stream.return_value = [
|
||||||
|
create_message_item(id="msg_A", text="The test data", output_index=0)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await ai_task.async_generate_data(
|
||||||
|
hass,
|
||||||
|
task_name="Test Task",
|
||||||
|
entity_id=entity_id,
|
||||||
|
instructions="Generate test data",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.data == "The test data"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
|
async def test_generate_structured_data(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_create_stream: AsyncMock,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test AI Task structured data generation."""
|
||||||
|
# Mock the OpenAI response stream with JSON data
|
||||||
|
mock_create_stream.return_value = [
|
||||||
|
create_message_item(
|
||||||
|
id="msg_A", text='{"characters": ["Mario", "Luigi"]}', output_index=0
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await ai_task.async_generate_data(
|
||||||
|
hass,
|
||||||
|
task_name="Test Task",
|
||||||
|
entity_id="ai_task.openai_ai_task",
|
||||||
|
instructions="Generate test data",
|
||||||
|
structure=vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required("characters"): selector.selector(
|
||||||
|
{
|
||||||
|
"text": {
|
||||||
|
"multiple": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.data == {"characters": ["Mario", "Luigi"]}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
|
async def test_generate_invalid_structured_data(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_create_stream: AsyncMock,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test AI Task with invalid JSON response."""
|
||||||
|
# Mock the OpenAI response stream with invalid JSON
|
||||||
|
mock_create_stream.return_value = [
|
||||||
|
create_message_item(id="msg_A", text="INVALID JSON RESPONSE", output_index=0)
|
||||||
|
]
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
HomeAssistantError, match="Error with OpenAI structured response"
|
||||||
|
):
|
||||||
|
await ai_task.async_generate_data(
|
||||||
|
hass,
|
||||||
|
task_name="Test Task",
|
||||||
|
entity_id="ai_task.openai_ai_task",
|
||||||
|
instructions="Generate test data",
|
||||||
|
structure=vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required("characters"): selector.selector(
|
||||||
|
{
|
||||||
|
"text": {
|
||||||
|
"multiple": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
@ -8,7 +8,9 @@ from openai.types.responses import Response, ResponseOutputMessage, ResponseOutp
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.components.openai_conversation.config_flow import RECOMMENDED_OPTIONS
|
from homeassistant.components.openai_conversation.config_flow import (
|
||||||
|
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||||
|
)
|
||||||
from homeassistant.components.openai_conversation.const import (
|
from homeassistant.components.openai_conversation.const import (
|
||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
CONF_MAX_TOKENS,
|
CONF_MAX_TOKENS,
|
||||||
@ -24,8 +26,10 @@ from homeassistant.components.openai_conversation.const import (
|
|||||||
CONF_WEB_SEARCH_REGION,
|
CONF_WEB_SEARCH_REGION,
|
||||||
CONF_WEB_SEARCH_TIMEZONE,
|
CONF_WEB_SEARCH_TIMEZONE,
|
||||||
CONF_WEB_SEARCH_USER_LOCATION,
|
CONF_WEB_SEARCH_USER_LOCATION,
|
||||||
|
DEFAULT_AI_TASK_NAME,
|
||||||
DEFAULT_CONVERSATION_NAME,
|
DEFAULT_CONVERSATION_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
RECOMMENDED_MAX_TOKENS,
|
RECOMMENDED_MAX_TOKENS,
|
||||||
RECOMMENDED_TOP_P,
|
RECOMMENDED_TOP_P,
|
||||||
@ -77,10 +81,16 @@ async def test_form(hass: HomeAssistant) -> None:
|
|||||||
assert result2["subentries"] == [
|
assert result2["subentries"] == [
|
||||||
{
|
{
|
||||||
"subentry_type": "conversation",
|
"subentry_type": "conversation",
|
||||||
"data": RECOMMENDED_OPTIONS,
|
"data": RECOMMENDED_CONVERSATION_OPTIONS,
|
||||||
"title": DEFAULT_CONVERSATION_NAME,
|
"title": DEFAULT_CONVERSATION_NAME,
|
||||||
"unique_id": None,
|
"unique_id": None,
|
||||||
}
|
},
|
||||||
|
{
|
||||||
|
"subentry_type": "ai_task_data",
|
||||||
|
"data": RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
|
"title": DEFAULT_AI_TASK_NAME,
|
||||||
|
"unique_id": None,
|
||||||
|
},
|
||||||
]
|
]
|
||||||
assert len(mock_setup_entry.mock_calls) == 1
|
assert len(mock_setup_entry.mock_calls) == 1
|
||||||
|
|
||||||
@ -131,14 +141,14 @@ async def test_creating_conversation_subentry(
|
|||||||
|
|
||||||
result2 = await hass.config_entries.subentries.async_configure(
|
result2 = await hass.config_entries.subentries.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
{"name": "My Custom Agent", **RECOMMENDED_OPTIONS},
|
{"name": "My Custom Agent", **RECOMMENDED_CONVERSATION_OPTIONS},
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||||
assert result2["title"] == "My Custom Agent"
|
assert result2["title"] == "My Custom Agent"
|
||||||
|
|
||||||
processed_options = RECOMMENDED_OPTIONS.copy()
|
processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
||||||
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
|
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
|
||||||
|
|
||||||
assert result2["data"] == processed_options
|
assert result2["data"] == processed_options
|
||||||
@ -709,3 +719,110 @@ async def test_subentry_web_search_user_location(
|
|||||||
CONF_WEB_SEARCH_COUNTRY: "US",
|
CONF_WEB_SEARCH_COUNTRY: "US",
|
||||||
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
|
CONF_WEB_SEARCH_TIMEZONE: "America/Los_Angeles",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_creating_ai_task_subentry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating an AI task subentry."""
|
||||||
|
old_subentries = set(mock_config_entry.subentries)
|
||||||
|
# Original conversation + original ai_task
|
||||||
|
assert len(mock_config_entry.subentries) == 2
|
||||||
|
|
||||||
|
result = await hass.config_entries.subentries.async_init(
|
||||||
|
(mock_config_entry.entry_id, "ai_task_data"),
|
||||||
|
context={"source": config_entries.SOURCE_USER},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.get("type") is FlowResultType.FORM
|
||||||
|
assert result.get("step_id") == "init"
|
||||||
|
assert not result.get("errors")
|
||||||
|
|
||||||
|
result2 = await hass.config_entries.subentries.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{
|
||||||
|
"name": "Custom AI Task",
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert result2.get("type") is FlowResultType.CREATE_ENTRY
|
||||||
|
assert result2.get("title") == "Custom AI Task"
|
||||||
|
assert result2.get("data") == {
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert (
|
||||||
|
len(mock_config_entry.subentries) == 3
|
||||||
|
) # Original conversation + original ai_task + new ai_task
|
||||||
|
|
||||||
|
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||||
|
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||||
|
assert new_subentry.subentry_type == "ai_task_data"
|
||||||
|
assert new_subentry.title == "Custom AI Task"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ai_task_subentry_not_loaded(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating an AI task subentry when entry is not loaded."""
|
||||||
|
# Don't call mock_init_component to simulate not loaded state
|
||||||
|
result = await hass.config_entries.subentries.async_init(
|
||||||
|
(mock_config_entry.entry_id, "ai_task_data"),
|
||||||
|
context={"source": config_entries.SOURCE_USER},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.get("type") is FlowResultType.ABORT
|
||||||
|
assert result.get("reason") == "entry_not_loaded"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_creating_ai_task_subentry_advanced(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating an AI task subentry with advanced settings."""
|
||||||
|
result = await hass.config_entries.subentries.async_init(
|
||||||
|
(mock_config_entry.entry_id, "ai_task_data"),
|
||||||
|
context={"source": config_entries.SOURCE_USER},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.get("type") is FlowResultType.FORM
|
||||||
|
assert result.get("step_id") == "init"
|
||||||
|
|
||||||
|
# Go to advanced settings
|
||||||
|
result2 = await hass.config_entries.subentries.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{
|
||||||
|
"name": "Advanced AI Task",
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result2.get("type") is FlowResultType.FORM
|
||||||
|
assert result2.get("step_id") == "advanced"
|
||||||
|
|
||||||
|
# Configure advanced settings
|
||||||
|
result3 = await hass.config_entries.subentries.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{
|
||||||
|
CONF_CHAT_MODEL: "gpt-4o",
|
||||||
|
CONF_MAX_TOKENS: 200,
|
||||||
|
CONF_TEMPERATURE: 0.5,
|
||||||
|
CONF_TOP_P: 0.9,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result3.get("type") is FlowResultType.CREATE_ENTRY
|
||||||
|
assert result3.get("title") == "Advanced AI Task"
|
||||||
|
assert result3.get("data") == {
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_CHAT_MODEL: "gpt-4o",
|
||||||
|
CONF_MAX_TOKENS: 200,
|
||||||
|
CONF_TEMPERATURE: 0.5,
|
||||||
|
CONF_TOP_P: 0.9,
|
||||||
|
}
|
||||||
|
@ -1,41 +1,15 @@
|
|||||||
"""Tests for the OpenAI integration."""
|
"""Tests for the OpenAI integration."""
|
||||||
|
|
||||||
from collections.abc import Generator
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from openai import AuthenticationError, RateLimitError
|
from openai import AuthenticationError, RateLimitError
|
||||||
from openai.types import ResponseFormatText
|
|
||||||
from openai.types.responses import (
|
from openai.types.responses import (
|
||||||
Response,
|
|
||||||
ResponseCompletedEvent,
|
|
||||||
ResponseContentPartAddedEvent,
|
|
||||||
ResponseContentPartDoneEvent,
|
|
||||||
ResponseCreatedEvent,
|
|
||||||
ResponseError,
|
ResponseError,
|
||||||
ResponseErrorEvent,
|
ResponseErrorEvent,
|
||||||
ResponseFailedEvent,
|
|
||||||
ResponseFunctionCallArgumentsDeltaEvent,
|
|
||||||
ResponseFunctionCallArgumentsDoneEvent,
|
|
||||||
ResponseFunctionToolCall,
|
|
||||||
ResponseFunctionWebSearch,
|
|
||||||
ResponseIncompleteEvent,
|
|
||||||
ResponseInProgressEvent,
|
|
||||||
ResponseOutputItemAddedEvent,
|
|
||||||
ResponseOutputItemDoneEvent,
|
|
||||||
ResponseOutputMessage,
|
|
||||||
ResponseOutputText,
|
|
||||||
ResponseReasoningItem,
|
|
||||||
ResponseStreamEvent,
|
ResponseStreamEvent,
|
||||||
ResponseTextConfig,
|
|
||||||
ResponseTextDeltaEvent,
|
|
||||||
ResponseTextDoneEvent,
|
|
||||||
ResponseWebSearchCallCompletedEvent,
|
|
||||||
ResponseWebSearchCallInProgressEvent,
|
|
||||||
ResponseWebSearchCallSearchingEvent,
|
|
||||||
)
|
)
|
||||||
from openai.types.responses.response import IncompleteDetails
|
from openai.types.responses.response import IncompleteDetails
|
||||||
from openai.types.responses.response_function_web_search import ActionSearch
|
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
@ -55,6 +29,13 @@ from homeassistant.core import Context, HomeAssistant
|
|||||||
from homeassistant.helpers import intent
|
from homeassistant.helpers import intent
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from . import (
|
||||||
|
create_function_tool_call_item,
|
||||||
|
create_message_item,
|
||||||
|
create_reasoning_item,
|
||||||
|
create_web_search_item,
|
||||||
|
)
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
from tests.components.conversation import (
|
from tests.components.conversation import (
|
||||||
MockChatLog,
|
MockChatLog,
|
||||||
@ -62,97 +43,6 @@ from tests.components.conversation import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_create_stream() -> Generator[AsyncMock]:
|
|
||||||
"""Mock stream response."""
|
|
||||||
|
|
||||||
async def mock_generator(events, **kwargs):
|
|
||||||
response = Response(
|
|
||||||
id="resp_A",
|
|
||||||
created_at=1700000000,
|
|
||||||
error=None,
|
|
||||||
incomplete_details=None,
|
|
||||||
instructions=kwargs.get("instructions"),
|
|
||||||
metadata=kwargs.get("metadata", {}),
|
|
||||||
model=kwargs.get("model", "gpt-4o-mini"),
|
|
||||||
object="response",
|
|
||||||
output=[],
|
|
||||||
parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
|
|
||||||
temperature=kwargs.get("temperature", 1.0),
|
|
||||||
tool_choice=kwargs.get("tool_choice", "auto"),
|
|
||||||
tools=kwargs.get("tools"),
|
|
||||||
top_p=kwargs.get("top_p", 1.0),
|
|
||||||
max_output_tokens=kwargs.get("max_output_tokens", 100000),
|
|
||||||
previous_response_id=kwargs.get("previous_response_id"),
|
|
||||||
reasoning=kwargs.get("reasoning"),
|
|
||||||
status="in_progress",
|
|
||||||
text=kwargs.get(
|
|
||||||
"text", ResponseTextConfig(format=ResponseFormatText(type="text"))
|
|
||||||
),
|
|
||||||
truncation=kwargs.get("truncation", "disabled"),
|
|
||||||
usage=None,
|
|
||||||
user=kwargs.get("user"),
|
|
||||||
store=kwargs.get("store", True),
|
|
||||||
)
|
|
||||||
yield ResponseCreatedEvent(
|
|
||||||
response=response,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.created",
|
|
||||||
)
|
|
||||||
yield ResponseInProgressEvent(
|
|
||||||
response=response,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.in_progress",
|
|
||||||
)
|
|
||||||
response.status = "completed"
|
|
||||||
|
|
||||||
for value in events:
|
|
||||||
if isinstance(value, ResponseOutputItemDoneEvent):
|
|
||||||
response.output.append(value.item)
|
|
||||||
elif isinstance(value, IncompleteDetails):
|
|
||||||
response.status = "incomplete"
|
|
||||||
response.incomplete_details = value
|
|
||||||
break
|
|
||||||
if isinstance(value, ResponseError):
|
|
||||||
response.status = "failed"
|
|
||||||
response.error = value
|
|
||||||
break
|
|
||||||
|
|
||||||
yield value
|
|
||||||
|
|
||||||
if isinstance(value, ResponseErrorEvent):
|
|
||||||
return
|
|
||||||
|
|
||||||
if response.status == "incomplete":
|
|
||||||
yield ResponseIncompleteEvent(
|
|
||||||
response=response,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.incomplete",
|
|
||||||
)
|
|
||||||
elif response.status == "failed":
|
|
||||||
yield ResponseFailedEvent(
|
|
||||||
response=response,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.failed",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
yield ResponseCompletedEvent(
|
|
||||||
response=response,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.completed",
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"openai.resources.responses.AsyncResponses.create",
|
|
||||||
AsyncMock(),
|
|
||||||
) as mock_create:
|
|
||||||
mock_create.side_effect = lambda **kwargs: mock_generator(
|
|
||||||
mock_create.return_value.pop(0), **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
yield mock_create
|
|
||||||
|
|
||||||
|
|
||||||
async def test_entity(
|
async def test_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
@ -347,225 +237,6 @@ async def test_conversation_agent(
|
|||||||
assert agent.supported_languages == "*"
|
assert agent.supported_languages == "*"
|
||||||
|
|
||||||
|
|
||||||
def create_message_item(
|
|
||||||
id: str, text: str | list[str], output_index: int
|
|
||||||
) -> list[ResponseStreamEvent]:
|
|
||||||
"""Create a message item."""
|
|
||||||
if isinstance(text, str):
|
|
||||||
text = [text]
|
|
||||||
|
|
||||||
content = ResponseOutputText(annotations=[], text="", type="output_text")
|
|
||||||
events = [
|
|
||||||
ResponseOutputItemAddedEvent(
|
|
||||||
item=ResponseOutputMessage(
|
|
||||||
id=id,
|
|
||||||
content=[],
|
|
||||||
type="message",
|
|
||||||
role="assistant",
|
|
||||||
status="in_progress",
|
|
||||||
),
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.output_item.added",
|
|
||||||
),
|
|
||||||
ResponseContentPartAddedEvent(
|
|
||||||
content_index=0,
|
|
||||||
item_id=id,
|
|
||||||
output_index=output_index,
|
|
||||||
part=content,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.content_part.added",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
content.text = "".join(text)
|
|
||||||
events.extend(
|
|
||||||
ResponseTextDeltaEvent(
|
|
||||||
content_index=0,
|
|
||||||
delta=delta,
|
|
||||||
item_id=id,
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.output_text.delta",
|
|
||||||
)
|
|
||||||
for delta in text
|
|
||||||
)
|
|
||||||
|
|
||||||
events.extend(
|
|
||||||
[
|
|
||||||
ResponseTextDoneEvent(
|
|
||||||
content_index=0,
|
|
||||||
item_id=id,
|
|
||||||
output_index=output_index,
|
|
||||||
text="".join(text),
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.output_text.done",
|
|
||||||
),
|
|
||||||
ResponseContentPartDoneEvent(
|
|
||||||
content_index=0,
|
|
||||||
item_id=id,
|
|
||||||
output_index=output_index,
|
|
||||||
part=content,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.content_part.done",
|
|
||||||
),
|
|
||||||
ResponseOutputItemDoneEvent(
|
|
||||||
item=ResponseOutputMessage(
|
|
||||||
id=id,
|
|
||||||
content=[content],
|
|
||||||
role="assistant",
|
|
||||||
status="completed",
|
|
||||||
type="message",
|
|
||||||
),
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.output_item.done",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return events
|
|
||||||
|
|
||||||
|
|
||||||
def create_function_tool_call_item(
|
|
||||||
id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
|
|
||||||
) -> list[ResponseStreamEvent]:
|
|
||||||
"""Create a function tool call item."""
|
|
||||||
if isinstance(arguments, str):
|
|
||||||
arguments = [arguments]
|
|
||||||
|
|
||||||
events = [
|
|
||||||
ResponseOutputItemAddedEvent(
|
|
||||||
item=ResponseFunctionToolCall(
|
|
||||||
id=id,
|
|
||||||
arguments="",
|
|
||||||
call_id=call_id,
|
|
||||||
name=name,
|
|
||||||
type="function_call",
|
|
||||||
status="in_progress",
|
|
||||||
),
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.output_item.added",
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
events.extend(
|
|
||||||
ResponseFunctionCallArgumentsDeltaEvent(
|
|
||||||
delta=delta,
|
|
||||||
item_id=id,
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.function_call_arguments.delta",
|
|
||||||
)
|
|
||||||
for delta in arguments
|
|
||||||
)
|
|
||||||
|
|
||||||
events.append(
|
|
||||||
ResponseFunctionCallArgumentsDoneEvent(
|
|
||||||
arguments="".join(arguments),
|
|
||||||
item_id=id,
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.function_call_arguments.done",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
events.append(
|
|
||||||
ResponseOutputItemDoneEvent(
|
|
||||||
item=ResponseFunctionToolCall(
|
|
||||||
id=id,
|
|
||||||
arguments="".join(arguments),
|
|
||||||
call_id=call_id,
|
|
||||||
name=name,
|
|
||||||
type="function_call",
|
|
||||||
status="completed",
|
|
||||||
),
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.output_item.done",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return events
|
|
||||||
|
|
||||||
|
|
||||||
def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
|
|
||||||
"""Create a reasoning item."""
|
|
||||||
return [
|
|
||||||
ResponseOutputItemAddedEvent(
|
|
||||||
item=ResponseReasoningItem(
|
|
||||||
id=id,
|
|
||||||
summary=[],
|
|
||||||
type="reasoning",
|
|
||||||
status=None,
|
|
||||||
encrypted_content="AAA",
|
|
||||||
),
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.output_item.added",
|
|
||||||
),
|
|
||||||
ResponseOutputItemDoneEvent(
|
|
||||||
item=ResponseReasoningItem(
|
|
||||||
id=id,
|
|
||||||
summary=[],
|
|
||||||
type="reasoning",
|
|
||||||
status=None,
|
|
||||||
encrypted_content="AAABBB",
|
|
||||||
),
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.output_item.done",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
|
|
||||||
"""Create a web search call item."""
|
|
||||||
return [
|
|
||||||
ResponseOutputItemAddedEvent(
|
|
||||||
item=ResponseFunctionWebSearch(
|
|
||||||
id=id,
|
|
||||||
status="in_progress",
|
|
||||||
action=ActionSearch(query="query", type="search"),
|
|
||||||
type="web_search_call",
|
|
||||||
),
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.output_item.added",
|
|
||||||
),
|
|
||||||
ResponseWebSearchCallInProgressEvent(
|
|
||||||
item_id=id,
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.web_search_call.in_progress",
|
|
||||||
),
|
|
||||||
ResponseWebSearchCallSearchingEvent(
|
|
||||||
item_id=id,
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.web_search_call.searching",
|
|
||||||
),
|
|
||||||
ResponseWebSearchCallCompletedEvent(
|
|
||||||
item_id=id,
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.web_search_call.completed",
|
|
||||||
),
|
|
||||||
ResponseOutputItemDoneEvent(
|
|
||||||
item=ResponseFunctionWebSearch(
|
|
||||||
id=id,
|
|
||||||
status="completed",
|
|
||||||
action=ActionSearch(query="query", type="search"),
|
|
||||||
type="web_search_call",
|
|
||||||
),
|
|
||||||
output_index=output_index,
|
|
||||||
sequence_number=0,
|
|
||||||
type="response.output_item.done",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def test_function_call(
|
async def test_function_call(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry_with_reasoning_model: MockConfigEntry,
|
mock_config_entry_with_reasoning_model: MockConfigEntry,
|
||||||
|
77
tests/components/openai_conversation/test_entity.py
Normal file
77
tests/components/openai_conversation/test_entity.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
"""Tests for the OpenAI Conversation entity."""
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components.openai_conversation.entity import (
|
||||||
|
_format_structured_output,
|
||||||
|
)
|
||||||
|
from homeassistant.helpers import selector
|
||||||
|
|
||||||
|
|
||||||
|
async def test_format_structured_output() -> None:
|
||||||
|
"""Test the format_structured_output function."""
|
||||||
|
schema = vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required("name"): selector.TextSelector(),
|
||||||
|
vol.Optional("age"): selector.NumberSelector(
|
||||||
|
config=selector.NumberSelectorConfig(
|
||||||
|
min=0,
|
||||||
|
max=120,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
vol.Required("stuff"): selector.ObjectSelector(
|
||||||
|
{
|
||||||
|
"multiple": True,
|
||||||
|
"fields": {
|
||||||
|
"item_name": {
|
||||||
|
"selector": {"text": None},
|
||||||
|
},
|
||||||
|
"item_value": {
|
||||||
|
"selector": {"text": None},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert _format_structured_output(schema, None) == {
|
||||||
|
"additionalProperties": False,
|
||||||
|
"properties": {
|
||||||
|
"age": {
|
||||||
|
"maximum": 120.0,
|
||||||
|
"minimum": 0.0,
|
||||||
|
"type": [
|
||||||
|
"number",
|
||||||
|
"null",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"stuff": {
|
||||||
|
"items": {
|
||||||
|
"properties": {
|
||||||
|
"item_name": {
|
||||||
|
"type": ["string", "null"],
|
||||||
|
},
|
||||||
|
"item_value": {
|
||||||
|
"type": ["string", "null"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"item_name",
|
||||||
|
"item_value",
|
||||||
|
],
|
||||||
|
"type": "object",
|
||||||
|
},
|
||||||
|
"type": "array",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"name",
|
||||||
|
"stuff",
|
||||||
|
"age",
|
||||||
|
],
|
||||||
|
"strict": True,
|
||||||
|
"type": "object",
|
||||||
|
}
|
@ -17,7 +17,10 @@ from syrupy.assertion import SnapshotAssertion
|
|||||||
from syrupy.filters import props
|
from syrupy.filters import props
|
||||||
|
|
||||||
from homeassistant.components.openai_conversation import CONF_CHAT_MODEL
|
from homeassistant.components.openai_conversation import CONF_CHAT_MODEL
|
||||||
from homeassistant.components.openai_conversation.const import DOMAIN
|
from homeassistant.components.openai_conversation.const import (
|
||||||
|
DEFAULT_AI_TASK_NAME,
|
||||||
|
DOMAIN,
|
||||||
|
)
|
||||||
from homeassistant.config_entries import ConfigSubentryData
|
from homeassistant.config_entries import ConfigSubentryData
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
|
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
|
||||||
@ -534,7 +537,7 @@ async def test_generate_content_service_error(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_migration_from_v1_to_v2(
|
async def test_migration_from_v1(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
@ -582,17 +585,33 @@ async def test_migration_from_v1_to_v2(
|
|||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert mock_config_entry.version == 2
|
assert mock_config_entry.version == 2
|
||||||
assert mock_config_entry.minor_version == 2
|
assert mock_config_entry.minor_version == 3
|
||||||
assert mock_config_entry.data == {"api_key": "1234"}
|
assert mock_config_entry.data == {"api_key": "1234"}
|
||||||
assert mock_config_entry.options == {}
|
assert mock_config_entry.options == {}
|
||||||
|
|
||||||
assert len(mock_config_entry.subentries) == 1
|
assert len(mock_config_entry.subentries) == 2
|
||||||
|
|
||||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
# Find the conversation subentry
|
||||||
assert subentry.unique_id is None
|
conversation_subentry = None
|
||||||
assert subentry.title == "ChatGPT"
|
ai_task_subentry = None
|
||||||
assert subentry.subentry_type == "conversation"
|
for subentry in mock_config_entry.subentries.values():
|
||||||
assert subentry.data == OPTIONS
|
if subentry.subentry_type == "conversation":
|
||||||
|
conversation_subentry = subentry
|
||||||
|
elif subentry.subentry_type == "ai_task_data":
|
||||||
|
ai_task_subentry = subentry
|
||||||
|
assert conversation_subentry is not None
|
||||||
|
assert conversation_subentry.unique_id is None
|
||||||
|
assert conversation_subentry.title == "ChatGPT"
|
||||||
|
assert conversation_subentry.subentry_type == "conversation"
|
||||||
|
assert conversation_subentry.data == OPTIONS
|
||||||
|
|
||||||
|
assert ai_task_subentry is not None
|
||||||
|
assert ai_task_subentry.unique_id is None
|
||||||
|
assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
|
||||||
|
assert ai_task_subentry.subentry_type == "ai_task_data"
|
||||||
|
|
||||||
|
# Use conversation subentry for the rest of the assertions
|
||||||
|
subentry = conversation_subentry
|
||||||
|
|
||||||
migrated_entity = entity_registry.async_get(entity.entity_id)
|
migrated_entity = entity_registry.async_get(entity.entity_id)
|
||||||
assert migrated_entity is not None
|
assert migrated_entity is not None
|
||||||
@ -617,12 +636,12 @@ async def test_migration_from_v1_to_v2(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_migration_from_v1_to_v2_with_multiple_keys(
|
async def test_migration_from_v1_with_multiple_keys(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test migration from version 1 to version 2 with different API keys."""
|
"""Test migration from version 1 with different API keys."""
|
||||||
# Create two v1 config entries with different API keys
|
# Create two v1 config entries with different API keys
|
||||||
options = {
|
options = {
|
||||||
"recommended": True,
|
"recommended": True,
|
||||||
@ -695,28 +714,38 @@ async def test_migration_from_v1_to_v2_with_multiple_keys(
|
|||||||
|
|
||||||
for idx, entry in enumerate(entries):
|
for idx, entry in enumerate(entries):
|
||||||
assert entry.version == 2
|
assert entry.version == 2
|
||||||
assert entry.minor_version == 2
|
assert entry.minor_version == 3
|
||||||
assert not entry.options
|
assert not entry.options
|
||||||
assert len(entry.subentries) == 1
|
assert len(entry.subentries) == 2
|
||||||
subentry = list(entry.subentries.values())[0]
|
|
||||||
assert subentry.subentry_type == "conversation"
|
conversation_subentry = None
|
||||||
assert subentry.data == options
|
for subentry in entry.subentries.values():
|
||||||
assert subentry.title == f"ChatGPT {idx + 1}"
|
if subentry.subentry_type == "conversation":
|
||||||
|
conversation_subentry = subentry
|
||||||
|
break
|
||||||
|
|
||||||
|
assert conversation_subentry is not None
|
||||||
|
assert conversation_subentry.subentry_type == "conversation"
|
||||||
|
assert conversation_subentry.data == options
|
||||||
|
assert conversation_subentry.title == f"ChatGPT {idx + 1}"
|
||||||
|
|
||||||
|
# Use conversation subentry for device assertions
|
||||||
|
subentry = conversation_subentry
|
||||||
|
|
||||||
dev = device_registry.async_get_device(
|
dev = device_registry.async_get_device(
|
||||||
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
|
identifiers={(DOMAIN, subentry.subentry_id)}
|
||||||
)
|
)
|
||||||
assert dev is not None
|
assert dev is not None
|
||||||
assert dev.config_entries == {entry.entry_id}
|
assert dev.config_entries == {entry.entry_id}
|
||||||
assert dev.config_entries_subentries == {entry.entry_id: {subentry.subentry_id}}
|
assert dev.config_entries_subentries == {entry.entry_id: {subentry.subentry_id}}
|
||||||
|
|
||||||
|
|
||||||
async def test_migration_from_v1_to_v2_with_same_keys(
|
async def test_migration_from_v1_with_same_keys(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test migration from version 1 to version 2 with same API keys consolidates entries."""
|
"""Test migration from version 1 with same API keys consolidates entries."""
|
||||||
# Create two v1 config entries with the same API key
|
# Create two v1 config entries with the same API key
|
||||||
options = {
|
options = {
|
||||||
"recommended": True,
|
"recommended": True,
|
||||||
@ -790,17 +819,28 @@ async def test_migration_from_v1_to_v2_with_same_keys(
|
|||||||
|
|
||||||
entry = entries[0]
|
entry = entries[0]
|
||||||
assert entry.version == 2
|
assert entry.version == 2
|
||||||
assert entry.minor_version == 2
|
assert entry.minor_version == 3
|
||||||
assert not entry.options
|
assert not entry.options
|
||||||
assert len(entry.subentries) == 2 # Two subentries from the two original entries
|
assert (
|
||||||
|
len(entry.subentries) == 3
|
||||||
|
) # Two conversation subentries + one AI task subentry
|
||||||
|
|
||||||
# Check both subentries exist with correct data
|
# Check both conversation subentries exist with correct data
|
||||||
subentries = list(entry.subentries.values())
|
conversation_subentries = [
|
||||||
titles = [sub.title for sub in subentries]
|
sub for sub in entry.subentries.values() if sub.subentry_type == "conversation"
|
||||||
|
]
|
||||||
|
ai_task_subentries = [
|
||||||
|
sub for sub in entry.subentries.values() if sub.subentry_type == "ai_task_data"
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(conversation_subentries) == 2
|
||||||
|
assert len(ai_task_subentries) == 1
|
||||||
|
|
||||||
|
titles = [sub.title for sub in conversation_subentries]
|
||||||
assert "ChatGPT" in titles
|
assert "ChatGPT" in titles
|
||||||
assert "ChatGPT 2" in titles
|
assert "ChatGPT 2" in titles
|
||||||
|
|
||||||
for subentry in subentries:
|
for subentry in conversation_subentries:
|
||||||
assert subentry.subentry_type == "conversation"
|
assert subentry.subentry_type == "conversation"
|
||||||
assert subentry.data == options
|
assert subentry.data == options
|
||||||
|
|
||||||
@ -815,12 +855,12 @@ async def test_migration_from_v1_to_v2_with_same_keys(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_migration_from_v2_1_to_v2_2(
|
async def test_migration_from_v2_1(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test migration from version 2.1 to version 2.2.
|
"""Test migration from version 2.1.
|
||||||
|
|
||||||
This tests we clean up the broken migration in Home Assistant Core
|
This tests we clean up the broken migration in Home Assistant Core
|
||||||
2025.7.0b0-2025.7.0b1:
|
2025.7.0b0-2025.7.0b1:
|
||||||
@ -913,16 +953,22 @@ async def test_migration_from_v2_1_to_v2_2(
|
|||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
entry = entries[0]
|
entry = entries[0]
|
||||||
assert entry.version == 2
|
assert entry.version == 2
|
||||||
assert entry.minor_version == 2
|
assert entry.minor_version == 3
|
||||||
assert not entry.options
|
assert not entry.options
|
||||||
assert entry.title == "ChatGPT"
|
assert entry.title == "ChatGPT"
|
||||||
assert len(entry.subentries) == 2
|
assert len(entry.subentries) == 3 # 2 conversation + 1 AI task
|
||||||
conversation_subentries = [
|
conversation_subentries = [
|
||||||
subentry
|
subentry
|
||||||
for subentry in entry.subentries.values()
|
for subentry in entry.subentries.values()
|
||||||
if subentry.subentry_type == "conversation"
|
if subentry.subentry_type == "conversation"
|
||||||
]
|
]
|
||||||
|
ai_task_subentries = [
|
||||||
|
subentry
|
||||||
|
for subentry in entry.subentries.values()
|
||||||
|
if subentry.subentry_type == "ai_task_data"
|
||||||
|
]
|
||||||
assert len(conversation_subentries) == 2
|
assert len(conversation_subentries) == 2
|
||||||
|
assert len(ai_task_subentries) == 1
|
||||||
for subentry in conversation_subentries:
|
for subentry in conversation_subentries:
|
||||||
assert subentry.subentry_type == "conversation"
|
assert subentry.subentry_type == "conversation"
|
||||||
assert subentry.data == options
|
assert subentry.data == options
|
||||||
@ -972,7 +1018,9 @@ async def test_migration_from_v2_1_to_v2_2(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("mock_subentry_data", [{}, {CONF_CHAT_MODEL: "gpt-1o"}])
|
@pytest.mark.parametrize(
|
||||||
|
"mock_conversation_subentry_data", [{}, {CONF_CHAT_MODEL: "gpt-1o"}]
|
||||||
|
)
|
||||||
async def test_devices(
|
async def test_devices(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
@ -980,12 +1028,89 @@ async def test_devices(
|
|||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Assert exception when invalid config entry is provided."""
|
"""Test devices are correctly created for subentries."""
|
||||||
devices = dr.async_entries_for_config_entry(
|
devices = dr.async_entries_for_config_entry(
|
||||||
device_registry, mock_config_entry.entry_id
|
device_registry, mock_config_entry.entry_id
|
||||||
)
|
)
|
||||||
assert len(devices) == 1
|
assert len(devices) == 2 # One for conversation, one for AI task
|
||||||
|
|
||||||
|
# Use the first device for snapshot comparison
|
||||||
device = devices[0]
|
device = devices[0]
|
||||||
assert device == snapshot(exclude=props("identifiers"))
|
assert device == snapshot(exclude=props("identifiers"))
|
||||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
# Verify the device has identifiers matching one of the subentries
|
||||||
assert device.identifiers == {(DOMAIN, subentry.subentry_id)}
|
expected_identifiers = [
|
||||||
|
{(DOMAIN, subentry.subentry_id)}
|
||||||
|
for subentry in mock_config_entry.subentries.values()
|
||||||
|
]
|
||||||
|
assert device.identifiers in expected_identifiers
|
||||||
|
|
||||||
|
|
||||||
|
async def test_migration_from_v2_2(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test migration from version 2.2."""
|
||||||
|
# Create a v2.2 config entry with a conversation subentry
|
||||||
|
options = {
|
||||||
|
"recommended": True,
|
||||||
|
"llm_hass_api": ["assist"],
|
||||||
|
"prompt": "You are a helpful assistant",
|
||||||
|
"chat_model": "gpt-4o-mini",
|
||||||
|
}
|
||||||
|
mock_config_entry = MockConfigEntry(
|
||||||
|
domain=DOMAIN,
|
||||||
|
data={"api_key": "1234"},
|
||||||
|
entry_id="mock_entry_id",
|
||||||
|
version=2,
|
||||||
|
minor_version=2,
|
||||||
|
subentries_data=[
|
||||||
|
ConfigSubentryData(
|
||||||
|
data=options,
|
||||||
|
subentry_id="mock_id_1",
|
||||||
|
subentry_type="conversation",
|
||||||
|
title="ChatGPT",
|
||||||
|
unique_id=None,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
title="ChatGPT",
|
||||||
|
)
|
||||||
|
mock_config_entry.add_to_hass(hass)
|
||||||
|
|
||||||
|
# Run migration
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.openai_conversation.async_setup_entry",
|
||||||
|
return_value=True,
|
||||||
|
):
|
||||||
|
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
entries = hass.config_entries.async_entries(DOMAIN)
|
||||||
|
assert len(entries) == 1
|
||||||
|
entry = entries[0]
|
||||||
|
assert entry.version == 2
|
||||||
|
assert entry.minor_version == 3
|
||||||
|
assert not entry.options
|
||||||
|
assert entry.title == "ChatGPT"
|
||||||
|
assert len(entry.subentries) == 2
|
||||||
|
|
||||||
|
# Check conversation subentry is still there
|
||||||
|
conversation_subentries = [
|
||||||
|
subentry
|
||||||
|
for subentry in entry.subentries.values()
|
||||||
|
if subentry.subentry_type == "conversation"
|
||||||
|
]
|
||||||
|
assert len(conversation_subentries) == 1
|
||||||
|
conversation_subentry = conversation_subentries[0]
|
||||||
|
assert conversation_subentry.data == options
|
||||||
|
|
||||||
|
# Check AI Task subentry was added
|
||||||
|
ai_task_subentries = [
|
||||||
|
subentry
|
||||||
|
for subentry in entry.subentries.values()
|
||||||
|
if subentry.subentry_type == "ai_task_data"
|
||||||
|
]
|
||||||
|
assert len(ai_task_subentries) == 1
|
||||||
|
ai_task_subentry = ai_task_subentries[0]
|
||||||
|
assert ai_task_subentry.data == {"recommended": True}
|
||||||
|
assert ai_task_subentry.title == "OpenAI AI Task"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user