mirror of
https://github.com/home-assistant/core.git
synced 2025-08-01 09:38:21 +00:00
Add AI Task to OpenRouter (#149275)
This commit is contained in:
parent
223c34056d
commit
1b58809655
@ -12,7 +12,7 @@ from homeassistant.helpers.httpx_client import get_async_client
|
|||||||
|
|
||||||
from .const import LOGGER
|
from .const import LOGGER
|
||||||
|
|
||||||
PLATFORMS = [Platform.CONVERSATION]
|
PLATFORMS = [Platform.AI_TASK, Platform.CONVERSATION]
|
||||||
|
|
||||||
type OpenRouterConfigEntry = ConfigEntry[AsyncOpenAI]
|
type OpenRouterConfigEntry = ConfigEntry[AsyncOpenAI]
|
||||||
|
|
||||||
|
75
homeassistant/components/open_router/ai_task.py
Normal file
75
homeassistant/components/open_router/ai_task.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
"""AI Task integration for OpenRouter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from json import JSONDecodeError
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from homeassistant.components import ai_task, conversation
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
from homeassistant.util.json import json_loads
|
||||||
|
|
||||||
|
from . import OpenRouterConfigEntry
|
||||||
|
from .entity import OpenRouterEntity
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: OpenRouterConfigEntry,
|
||||||
|
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up AI Task entities."""
|
||||||
|
for subentry in config_entry.subentries.values():
|
||||||
|
if subentry.subentry_type != "ai_task_data":
|
||||||
|
continue
|
||||||
|
|
||||||
|
async_add_entities(
|
||||||
|
[OpenRouterAITaskEntity(config_entry, subentry)],
|
||||||
|
config_subentry_id=subentry.subentry_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterAITaskEntity(
|
||||||
|
ai_task.AITaskEntity,
|
||||||
|
OpenRouterEntity,
|
||||||
|
):
|
||||||
|
"""OpenRouter AI Task entity."""
|
||||||
|
|
||||||
|
_attr_name = None
|
||||||
|
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||||
|
|
||||||
|
async def _async_generate_data(
|
||||||
|
self,
|
||||||
|
task: ai_task.GenDataTask,
|
||||||
|
chat_log: conversation.ChatLog,
|
||||||
|
) -> ai_task.GenDataTaskResult:
|
||||||
|
"""Handle a generate data task."""
|
||||||
|
await self._async_handle_chat_log(chat_log, task.name, task.structure)
|
||||||
|
|
||||||
|
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||||
|
raise HomeAssistantError(
|
||||||
|
"Last content in chat log is not an AssistantContent"
|
||||||
|
)
|
||||||
|
|
||||||
|
text = chat_log.content[-1].content or ""
|
||||||
|
|
||||||
|
if not task.structure:
|
||||||
|
return ai_task.GenDataTaskResult(
|
||||||
|
conversation_id=chat_log.conversation_id,
|
||||||
|
data=text,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
data = json_loads(text)
|
||||||
|
except JSONDecodeError as err:
|
||||||
|
raise HomeAssistantError(
|
||||||
|
"Error with OpenRouter structured response"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
return ai_task.GenDataTaskResult(
|
||||||
|
conversation_id=chat_log.conversation_id,
|
||||||
|
data=data,
|
||||||
|
)
|
@ -5,7 +5,12 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from python_open_router import Model, OpenRouterClient, OpenRouterError
|
from python_open_router import (
|
||||||
|
Model,
|
||||||
|
OpenRouterClient,
|
||||||
|
OpenRouterError,
|
||||||
|
SupportedParameter,
|
||||||
|
)
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.config_entries import (
|
from homeassistant.config_entries import (
|
||||||
@ -43,7 +48,10 @@ class OpenRouterConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
cls, config_entry: ConfigEntry
|
cls, config_entry: ConfigEntry
|
||||||
) -> dict[str, type[ConfigSubentryFlow]]:
|
) -> dict[str, type[ConfigSubentryFlow]]:
|
||||||
"""Return subentries supported by this handler."""
|
"""Return subentries supported by this handler."""
|
||||||
return {"conversation": ConversationFlowHandler}
|
return {
|
||||||
|
"conversation": ConversationFlowHandler,
|
||||||
|
"ai_task_data": AITaskDataFlowHandler,
|
||||||
|
}
|
||||||
|
|
||||||
async def async_step_user(
|
async def async_step_user(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
@ -78,13 +86,26 @@ class OpenRouterConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConversationFlowHandler(ConfigSubentryFlow):
|
class OpenRouterSubentryFlowHandler(ConfigSubentryFlow):
|
||||||
"""Handle subentry flow."""
|
"""Handle subentry flow for OpenRouter."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize the subentry flow."""
|
"""Initialize the subentry flow."""
|
||||||
self.models: dict[str, Model] = {}
|
self.models: dict[str, Model] = {}
|
||||||
|
|
||||||
|
async def _get_models(self) -> None:
|
||||||
|
"""Fetch models from OpenRouter."""
|
||||||
|
entry = self._get_entry()
|
||||||
|
client = OpenRouterClient(
|
||||||
|
entry.data[CONF_API_KEY], async_get_clientsession(self.hass)
|
||||||
|
)
|
||||||
|
models = await client.get_models()
|
||||||
|
self.models = {model.id: model for model in models}
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationFlowHandler(OpenRouterSubentryFlowHandler):
|
||||||
|
"""Handle subentry flow."""
|
||||||
|
|
||||||
async def async_step_user(
|
async def async_step_user(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> SubentryFlowResult:
|
) -> SubentryFlowResult:
|
||||||
@ -95,14 +116,16 @@ class ConversationFlowHandler(ConfigSubentryFlow):
|
|||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=self.models[user_input[CONF_MODEL]].name, data=user_input
|
title=self.models[user_input[CONF_MODEL]].name, data=user_input
|
||||||
)
|
)
|
||||||
entry = self._get_entry()
|
try:
|
||||||
client = OpenRouterClient(
|
await self._get_models()
|
||||||
entry.data[CONF_API_KEY], async_get_clientsession(self.hass)
|
except OpenRouterError:
|
||||||
)
|
return self.async_abort(reason="cannot_connect")
|
||||||
models = await client.get_models()
|
except Exception:
|
||||||
self.models = {model.id: model for model in models}
|
_LOGGER.exception("Unexpected exception")
|
||||||
|
return self.async_abort(reason="unknown")
|
||||||
options = [
|
options = [
|
||||||
SelectOptionDict(value=model.id, label=model.name) for model in models
|
SelectOptionDict(value=model.id, label=model.name)
|
||||||
|
for model in self.models.values()
|
||||||
]
|
]
|
||||||
|
|
||||||
hass_apis: list[SelectOptionDict] = [
|
hass_apis: list[SelectOptionDict] = [
|
||||||
@ -138,3 +161,40 @@ class ConversationFlowHandler(ConfigSubentryFlow):
|
|||||||
}
|
}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AITaskDataFlowHandler(OpenRouterSubentryFlowHandler):
|
||||||
|
"""Handle subentry flow."""
|
||||||
|
|
||||||
|
async def async_step_user(
|
||||||
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
) -> SubentryFlowResult:
|
||||||
|
"""User flow to create a sensor subentry."""
|
||||||
|
if user_input is not None:
|
||||||
|
return self.async_create_entry(
|
||||||
|
title=self.models[user_input[CONF_MODEL]].name, data=user_input
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await self._get_models()
|
||||||
|
except OpenRouterError:
|
||||||
|
return self.async_abort(reason="cannot_connect")
|
||||||
|
except Exception:
|
||||||
|
_LOGGER.exception("Unexpected exception")
|
||||||
|
return self.async_abort(reason="unknown")
|
||||||
|
options = [
|
||||||
|
SelectOptionDict(value=model.id, label=model.name)
|
||||||
|
for model in self.models.values()
|
||||||
|
if SupportedParameter.STRUCTURED_OUTPUTS in model.supported_parameters
|
||||||
|
]
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="user",
|
||||||
|
data_schema=vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required(CONF_MODEL): SelectSelector(
|
||||||
|
SelectSelectorConfig(
|
||||||
|
options=options, mode=SelectSelectorMode.DROPDOWN, sort=True
|
||||||
|
),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@ -20,6 +20,8 @@ async def async_setup_entry(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Set up conversation entities."""
|
"""Set up conversation entities."""
|
||||||
for subentry_id, subentry in config_entry.subentries.items():
|
for subentry_id, subentry in config_entry.subentries.items():
|
||||||
|
if subentry.subentry_type != "conversation":
|
||||||
|
continue
|
||||||
async_add_entities(
|
async_add_entities(
|
||||||
[OpenRouterConversationEntity(config_entry, subentry)],
|
[OpenRouterConversationEntity(config_entry, subentry)],
|
||||||
config_subentry_id=subentry_id,
|
config_subentry_id=subentry_id,
|
||||||
|
@ -4,10 +4,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
import json
|
import json
|
||||||
from typing import Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from openai import NOT_GIVEN
|
|
||||||
from openai.types.chat import (
|
from openai.types.chat import (
|
||||||
ChatCompletionAssistantMessageParam,
|
ChatCompletionAssistantMessageParam,
|
||||||
ChatCompletionMessage,
|
ChatCompletionMessage,
|
||||||
@ -19,7 +18,9 @@ from openai.types.chat import (
|
|||||||
ChatCompletionUserMessageParam,
|
ChatCompletionUserMessageParam,
|
||||||
)
|
)
|
||||||
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
||||||
from openai.types.shared_params import FunctionDefinition
|
from openai.types.shared_params import FunctionDefinition, ResponseFormatJSONSchema
|
||||||
|
from openai.types.shared_params.response_format_json_schema import JSONSchema
|
||||||
|
import voluptuous as vol
|
||||||
from voluptuous_openapi import convert
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
@ -36,6 +37,50 @@ from .const import DOMAIN, LOGGER
|
|||||||
MAX_TOOL_ITERATIONS = 10
|
MAX_TOOL_ITERATIONS = 10
|
||||||
|
|
||||||
|
|
||||||
|
def _adjust_schema(schema: dict[str, Any]) -> None:
|
||||||
|
"""Adjust the schema to be compatible with OpenRouter API."""
|
||||||
|
if schema["type"] == "object":
|
||||||
|
if "properties" not in schema:
|
||||||
|
return
|
||||||
|
|
||||||
|
if "required" not in schema:
|
||||||
|
schema["required"] = []
|
||||||
|
|
||||||
|
# Ensure all properties are required
|
||||||
|
for prop, prop_info in schema["properties"].items():
|
||||||
|
_adjust_schema(prop_info)
|
||||||
|
if prop not in schema["required"]:
|
||||||
|
prop_info["type"] = [prop_info["type"], "null"]
|
||||||
|
schema["required"].append(prop)
|
||||||
|
|
||||||
|
elif schema["type"] == "array":
|
||||||
|
if "items" not in schema:
|
||||||
|
return
|
||||||
|
|
||||||
|
_adjust_schema(schema["items"])
|
||||||
|
|
||||||
|
|
||||||
|
def _format_structured_output(
|
||||||
|
name: str, schema: vol.Schema, llm_api: llm.APIInstance | None
|
||||||
|
) -> JSONSchema:
|
||||||
|
"""Format the schema to be compatible with OpenRouter API."""
|
||||||
|
result: JSONSchema = {
|
||||||
|
"name": name,
|
||||||
|
"strict": True,
|
||||||
|
}
|
||||||
|
result_schema = convert(
|
||||||
|
schema,
|
||||||
|
custom_serializer=(
|
||||||
|
llm_api.custom_serializer if llm_api else llm.selector_serializer
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
_adjust_schema(result_schema)
|
||||||
|
|
||||||
|
result["schema"] = result_schema
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _format_tool(
|
def _format_tool(
|
||||||
tool: llm.Tool,
|
tool: llm.Tool,
|
||||||
custom_serializer: Callable[[Any], Any] | None,
|
custom_serializer: Callable[[Any], Any] | None,
|
||||||
@ -136,9 +181,24 @@ class OpenRouterEntity(Entity):
|
|||||||
entry_type=dr.DeviceEntryType.SERVICE,
|
entry_type=dr.DeviceEntryType.SERVICE,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _async_handle_chat_log(self, chat_log: conversation.ChatLog) -> None:
|
async def _async_handle_chat_log(
|
||||||
|
self,
|
||||||
|
chat_log: conversation.ChatLog,
|
||||||
|
structure_name: str | None = None,
|
||||||
|
structure: vol.Schema | None = None,
|
||||||
|
) -> None:
|
||||||
"""Generate an answer for the chat log."""
|
"""Generate an answer for the chat log."""
|
||||||
|
|
||||||
|
model_args = {
|
||||||
|
"model": self.model,
|
||||||
|
"user": chat_log.conversation_id,
|
||||||
|
"extra_headers": {
|
||||||
|
"X-Title": "Home Assistant",
|
||||||
|
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
||||||
|
},
|
||||||
|
"extra_body": {"require_parameters": True},
|
||||||
|
}
|
||||||
|
|
||||||
tools: list[ChatCompletionToolParam] | None = None
|
tools: list[ChatCompletionToolParam] | None = None
|
||||||
if chat_log.llm_api:
|
if chat_log.llm_api:
|
||||||
tools = [
|
tools = [
|
||||||
@ -146,33 +206,37 @@ class OpenRouterEntity(Entity):
|
|||||||
for tool in chat_log.llm_api.tools
|
for tool in chat_log.llm_api.tools
|
||||||
]
|
]
|
||||||
|
|
||||||
messages = [
|
if tools:
|
||||||
|
model_args["tools"] = tools
|
||||||
|
|
||||||
|
model_args["messages"] = [
|
||||||
m
|
m
|
||||||
for content in chat_log.content
|
for content in chat_log.content
|
||||||
if (m := _convert_content_to_chat_message(content))
|
if (m := _convert_content_to_chat_message(content))
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if structure:
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert structure_name is not None
|
||||||
|
model_args["response_format"] = ResponseFormatJSONSchema(
|
||||||
|
type="json_schema",
|
||||||
|
json_schema=_format_structured_output(
|
||||||
|
structure_name, structure, chat_log.llm_api
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
client = self.entry.runtime_data
|
client = self.entry.runtime_data
|
||||||
|
|
||||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||||
try:
|
try:
|
||||||
result = await client.chat.completions.create(
|
result = await client.chat.completions.create(**model_args)
|
||||||
model=self.model,
|
|
||||||
messages=messages,
|
|
||||||
tools=tools or NOT_GIVEN,
|
|
||||||
user=chat_log.conversation_id,
|
|
||||||
extra_headers={
|
|
||||||
"X-Title": "Home Assistant",
|
|
||||||
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except openai.OpenAIError as err:
|
except openai.OpenAIError as err:
|
||||||
LOGGER.error("Error talking to API: %s", err)
|
LOGGER.error("Error talking to API: %s", err)
|
||||||
raise HomeAssistantError("Error talking to API") from err
|
raise HomeAssistantError("Error talking to API") from err
|
||||||
|
|
||||||
result_message = result.choices[0].message
|
result_message = result.choices[0].message
|
||||||
|
|
||||||
messages.extend(
|
model_args["messages"].extend(
|
||||||
[
|
[
|
||||||
msg
|
msg
|
||||||
async for content in chat_log.async_add_delta_content_stream(
|
async for content in chat_log.async_add_delta_content_stream(
|
||||||
|
@ -37,7 +37,28 @@
|
|||||||
"initiate_flow": {
|
"initiate_flow": {
|
||||||
"user": "Add conversation agent"
|
"user": "Add conversation agent"
|
||||||
},
|
},
|
||||||
"entry_type": "Conversation agent"
|
"entry_type": "Conversation agent",
|
||||||
|
"abort": {
|
||||||
|
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||||
|
"unknown": "[%key:common::config_flow::error::unknown%]"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ai_task_data": {
|
||||||
|
"step": {
|
||||||
|
"user": {
|
||||||
|
"data": {
|
||||||
|
"model": "[%key:component::open_router::config_subentries::conversation::step::user::data::model%]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"initiate_flow": {
|
||||||
|
"user": "Add Generate data with AI service"
|
||||||
|
},
|
||||||
|
"entry_type": "Generate data with AI service",
|
||||||
|
"abort": {
|
||||||
|
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||||
|
"unknown": "[%key:common::config_flow::error::unknown%]"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -49,9 +49,19 @@ def conversation_subentry_data(enable_assist: bool) -> dict[str, Any]:
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ai_task_data_subentry_data() -> dict[str, Any]:
|
||||||
|
"""Mock AI task subentry data."""
|
||||||
|
return {
|
||||||
|
CONF_MODEL: "google/gemini-1.5-pro",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config_entry(
|
def mock_config_entry(
|
||||||
hass: HomeAssistant, conversation_subentry_data: dict[str, Any]
|
hass: HomeAssistant,
|
||||||
|
conversation_subentry_data: dict[str, Any],
|
||||||
|
ai_task_data_subentry_data: dict[str, Any],
|
||||||
) -> MockConfigEntry:
|
) -> MockConfigEntry:
|
||||||
"""Mock a config entry."""
|
"""Mock a config entry."""
|
||||||
return MockConfigEntry(
|
return MockConfigEntry(
|
||||||
@ -67,7 +77,14 @@ def mock_config_entry(
|
|||||||
subentry_type="conversation",
|
subentry_type="conversation",
|
||||||
title="GPT-3.5 Turbo",
|
title="GPT-3.5 Turbo",
|
||||||
unique_id=None,
|
unique_id=None,
|
||||||
)
|
),
|
||||||
|
ConfigSubentryData(
|
||||||
|
data=ai_task_data_subentry_data,
|
||||||
|
subentry_id="ABCDEG",
|
||||||
|
subentry_type="ai_task_data",
|
||||||
|
title="Gemini 1.5 Pro",
|
||||||
|
unique_id=None,
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -85,6 +85,7 @@
|
|||||||
"logit_bias",
|
"logit_bias",
|
||||||
"logprobs",
|
"logprobs",
|
||||||
"top_logprobs",
|
"top_logprobs",
|
||||||
|
"structured_outputs",
|
||||||
"response_format"
|
"response_format"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
53
tests/components/open_router/snapshots/test_ai_task.ambr
Normal file
53
tests/components/open_router/snapshots/test_ai_task.ambr
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_all_entities[ai_task.gemini_1_5_pro-entry]
|
||||||
|
EntityRegistryEntrySnapshot({
|
||||||
|
'aliases': set({
|
||||||
|
}),
|
||||||
|
'area_id': None,
|
||||||
|
'capabilities': None,
|
||||||
|
'config_entry_id': <ANY>,
|
||||||
|
'config_subentry_id': <ANY>,
|
||||||
|
'device_class': None,
|
||||||
|
'device_id': <ANY>,
|
||||||
|
'disabled_by': None,
|
||||||
|
'domain': 'ai_task',
|
||||||
|
'entity_category': None,
|
||||||
|
'entity_id': 'ai_task.gemini_1_5_pro',
|
||||||
|
'has_entity_name': True,
|
||||||
|
'hidden_by': None,
|
||||||
|
'icon': None,
|
||||||
|
'id': <ANY>,
|
||||||
|
'labels': set({
|
||||||
|
}),
|
||||||
|
'name': None,
|
||||||
|
'options': dict({
|
||||||
|
'conversation': dict({
|
||||||
|
'should_expose': False,
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'original_device_class': None,
|
||||||
|
'original_icon': None,
|
||||||
|
'original_name': None,
|
||||||
|
'platform': 'open_router',
|
||||||
|
'previous_unique_id': None,
|
||||||
|
'suggested_object_id': None,
|
||||||
|
'supported_features': <AITaskEntityFeature: 1>,
|
||||||
|
'translation_key': None,
|
||||||
|
'unique_id': 'ABCDEG',
|
||||||
|
'unit_of_measurement': None,
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_all_entities[ai_task.gemini_1_5_pro-state]
|
||||||
|
StateSnapshot({
|
||||||
|
'attributes': ReadOnlyDict({
|
||||||
|
'friendly_name': 'Gemini 1.5 Pro',
|
||||||
|
'supported_features': <AITaskEntityFeature: 1>,
|
||||||
|
}),
|
||||||
|
'context': <ANY>,
|
||||||
|
'entity_id': 'ai_task.gemini_1_5_pro',
|
||||||
|
'last_changed': <ANY>,
|
||||||
|
'last_reported': <ANY>,
|
||||||
|
'last_updated': <ANY>,
|
||||||
|
'state': 'unknown',
|
||||||
|
})
|
||||||
|
# ---
|
210
tests/components/open_router/test_ai_task.py
Normal file
210
tests/components/open_router/test_ai_task.py
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
"""Test AI Task structured data generation."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
from openai.types import CompletionUsage
|
||||||
|
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||||
|
from openai.types.chat.chat_completion import Choice
|
||||||
|
import pytest
|
||||||
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components import ai_task
|
||||||
|
from homeassistant.const import Platform
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers import entity_registry as er, selector
|
||||||
|
|
||||||
|
from . import setup_integration
|
||||||
|
|
||||||
|
from tests.common import MockConfigEntry, snapshot_platform
|
||||||
|
|
||||||
|
|
||||||
|
async def test_all_entities(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
mock_openai_client: AsyncMock,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test all entities."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.open_router.PLATFORMS",
|
||||||
|
[Platform.AI_TASK],
|
||||||
|
):
|
||||||
|
await setup_integration(hass, mock_config_entry)
|
||||||
|
|
||||||
|
await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_data(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_openai_client: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test AI Task data generation."""
|
||||||
|
await setup_integration(hass, mock_config_entry)
|
||||||
|
|
||||||
|
entity_id = "ai_task.gemini_1_5_pro"
|
||||||
|
|
||||||
|
mock_openai_client.chat.completions.create = AsyncMock(
|
||||||
|
return_value=ChatCompletion(
|
||||||
|
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
finish_reason="stop",
|
||||||
|
index=0,
|
||||||
|
message=ChatCompletionMessage(
|
||||||
|
content="The test data",
|
||||||
|
role="assistant",
|
||||||
|
function_call=None,
|
||||||
|
tool_calls=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=1700000000,
|
||||||
|
model="x-ai/grok-3",
|
||||||
|
object="chat.completion",
|
||||||
|
system_fingerprint=None,
|
||||||
|
usage=CompletionUsage(
|
||||||
|
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await ai_task.async_generate_data(
|
||||||
|
hass,
|
||||||
|
task_name="Test Task",
|
||||||
|
entity_id=entity_id,
|
||||||
|
instructions="Generate test data",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.data == "The test data"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_structured_data(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_openai_client: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test AI Task structured data generation."""
|
||||||
|
await setup_integration(hass, mock_config_entry)
|
||||||
|
|
||||||
|
mock_openai_client.chat.completions.create = AsyncMock(
|
||||||
|
return_value=ChatCompletion(
|
||||||
|
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
finish_reason="stop",
|
||||||
|
index=0,
|
||||||
|
message=ChatCompletionMessage(
|
||||||
|
content='{"characters": ["Mario", "Luigi"]}',
|
||||||
|
role="assistant",
|
||||||
|
function_call=None,
|
||||||
|
tool_calls=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=1700000000,
|
||||||
|
model="x-ai/grok-3",
|
||||||
|
object="chat.completion",
|
||||||
|
system_fingerprint=None,
|
||||||
|
usage=CompletionUsage(
|
||||||
|
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await ai_task.async_generate_data(
|
||||||
|
hass,
|
||||||
|
task_name="Test Task",
|
||||||
|
entity_id="ai_task.gemini_1_5_pro",
|
||||||
|
instructions="Generate test data",
|
||||||
|
structure=vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required("characters"): selector.selector(
|
||||||
|
{
|
||||||
|
"text": {
|
||||||
|
"multiple": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.data == {"characters": ["Mario", "Luigi"]}
|
||||||
|
assert mock_openai_client.chat.completions.create.call_args_list[0][1][
|
||||||
|
"response_format"
|
||||||
|
] == {
|
||||||
|
"json_schema": {
|
||||||
|
"name": "Test Task",
|
||||||
|
"schema": {
|
||||||
|
"properties": {
|
||||||
|
"characters": {
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"type": "array",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["characters"],
|
||||||
|
"type": "object",
|
||||||
|
},
|
||||||
|
"strict": True,
|
||||||
|
},
|
||||||
|
"type": "json_schema",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_invalid_structured_data(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_openai_client: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test AI Task with invalid JSON response."""
|
||||||
|
await setup_integration(hass, mock_config_entry)
|
||||||
|
|
||||||
|
mock_openai_client.chat.completions.create = AsyncMock(
|
||||||
|
return_value=ChatCompletion(
|
||||||
|
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
finish_reason="stop",
|
||||||
|
index=0,
|
||||||
|
message=ChatCompletionMessage(
|
||||||
|
content="INVALID JSON RESPONSE",
|
||||||
|
role="assistant",
|
||||||
|
function_call=None,
|
||||||
|
tool_calls=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=1700000000,
|
||||||
|
model="x-ai/grok-3",
|
||||||
|
object="chat.completion",
|
||||||
|
system_fingerprint=None,
|
||||||
|
usage=CompletionUsage(
|
||||||
|
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
HomeAssistantError, match="Error with OpenRouter structured response"
|
||||||
|
):
|
||||||
|
await ai_task.async_generate_data(
|
||||||
|
hass,
|
||||||
|
task_name="Test Task",
|
||||||
|
entity_id="ai_task.gemini_1_5_pro",
|
||||||
|
instructions="Generate test data",
|
||||||
|
structure=vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required("characters"): selector.selector(
|
||||||
|
{
|
||||||
|
"text": {
|
||||||
|
"multiple": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
@ -110,9 +110,6 @@ async def test_create_conversation_agent(
|
|||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test creating a conversation agent."""
|
"""Test creating a conversation agent."""
|
||||||
|
|
||||||
mock_config_entry.add_to_hass(hass)
|
|
||||||
|
|
||||||
await setup_integration(hass, mock_config_entry)
|
await setup_integration(hass, mock_config_entry)
|
||||||
|
|
||||||
result = await hass.config_entries.subentries.async_init(
|
result = await hass.config_entries.subentries.async_init(
|
||||||
@ -152,9 +149,6 @@ async def test_create_conversation_agent_no_control(
|
|||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test creating a conversation agent without control over the LLM API."""
|
"""Test creating a conversation agent without control over the LLM API."""
|
||||||
|
|
||||||
mock_config_entry.add_to_hass(hass)
|
|
||||||
|
|
||||||
await setup_integration(hass, mock_config_entry)
|
await setup_integration(hass, mock_config_entry)
|
||||||
|
|
||||||
result = await hass.config_entries.subentries.async_init(
|
result = await hass.config_entries.subentries.async_init(
|
||||||
@ -184,3 +178,63 @@ async def test_create_conversation_agent_no_control(
|
|||||||
CONF_MODEL: "openai/gpt-3.5-turbo",
|
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||||
CONF_PROMPT: "you are an assistant",
|
CONF_PROMPT: "you are an assistant",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_ai_task(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_open_router_client: AsyncMock,
|
||||||
|
mock_openai_client: AsyncMock,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating an AI Task."""
|
||||||
|
await setup_integration(hass, mock_config_entry)
|
||||||
|
|
||||||
|
result = await hass.config_entries.subentries.async_init(
|
||||||
|
(mock_config_entry.entry_id, "ai_task_data"),
|
||||||
|
context={"source": SOURCE_USER},
|
||||||
|
)
|
||||||
|
assert result["type"] is FlowResultType.FORM
|
||||||
|
assert not result["errors"]
|
||||||
|
assert result["step_id"] == "user"
|
||||||
|
|
||||||
|
assert result["data_schema"].schema["model"].config["options"] == [
|
||||||
|
{"value": "openai/gpt-4", "label": "OpenAI: GPT-4"},
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await hass.config_entries.subentries.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{CONF_MODEL: "openai/gpt-4"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||||
|
assert result["data"] == {CONF_MODEL: "openai/gpt-4"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"subentry_type",
|
||||||
|
["conversation", "ai_task_data"],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("exception", "reason"),
|
||||||
|
[(OpenRouterError("exception"), "cannot_connect"), (Exception, "unknown")],
|
||||||
|
)
|
||||||
|
async def test_subentry_exceptions(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_open_router_client: AsyncMock,
|
||||||
|
mock_openai_client: AsyncMock,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
subentry_type: str,
|
||||||
|
exception: Exception,
|
||||||
|
reason: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test subentry flow exceptions."""
|
||||||
|
await setup_integration(hass, mock_config_entry)
|
||||||
|
|
||||||
|
mock_open_router_client.get_models.side_effect = exception
|
||||||
|
|
||||||
|
result = await hass.config_entries.subentries.async_init(
|
||||||
|
(mock_config_entry.entry_id, subentry_type),
|
||||||
|
context={"source": SOURCE_USER},
|
||||||
|
)
|
||||||
|
assert result["type"] is FlowResultType.ABORT
|
||||||
|
assert result["reason"] == reason
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Tests for the OpenRouter integration."""
|
"""Tests for the OpenRouter integration."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
from openai.types import CompletionUsage
|
from openai.types import CompletionUsage
|
||||||
@ -15,6 +15,7 @@ import pytest
|
|||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
|
from homeassistant.const import Platform
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.helpers import entity_registry as er, intent
|
from homeassistant.helpers import entity_registry as er, intent
|
||||||
|
|
||||||
@ -40,7 +41,11 @@ async def test_all_entities(
|
|||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test all entities."""
|
"""Test all entities."""
|
||||||
await setup_integration(hass, mock_config_entry)
|
with patch(
|
||||||
|
"homeassistant.components.open_router.PLATFORMS",
|
||||||
|
[Platform.CONVERSATION],
|
||||||
|
):
|
||||||
|
await setup_integration(hass, mock_config_entry)
|
||||||
|
|
||||||
await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id)
|
await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user