mirror of
https://github.com/home-assistant/core.git
synced 2025-07-31 17:18:23 +00:00
Add AI Task to OpenRouter (#149275)
This commit is contained in:
parent
223c34056d
commit
1b58809655
@ -12,7 +12,7 @@ from homeassistant.helpers.httpx_client import get_async_client
|
||||
|
||||
from .const import LOGGER
|
||||
|
||||
PLATFORMS = [Platform.CONVERSATION]
|
||||
PLATFORMS = [Platform.AI_TASK, Platform.CONVERSATION]
|
||||
|
||||
type OpenRouterConfigEntry = ConfigEntry[AsyncOpenAI]
|
||||
|
||||
|
75
homeassistant/components/open_router/ai_task.py
Normal file
75
homeassistant/components/open_router/ai_task.py
Normal file
@ -0,0 +1,75 @@
|
||||
"""AI Task integration for OpenRouter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from json import JSONDecodeError
|
||||
import logging
|
||||
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
from homeassistant.util.json import json_loads
|
||||
|
||||
from . import OpenRouterConfigEntry
|
||||
from .entity import OpenRouterEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: OpenRouterConfigEntry,
|
||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up AI Task entities."""
|
||||
for subentry in config_entry.subentries.values():
|
||||
if subentry.subentry_type != "ai_task_data":
|
||||
continue
|
||||
|
||||
async_add_entities(
|
||||
[OpenRouterAITaskEntity(config_entry, subentry)],
|
||||
config_subentry_id=subentry.subentry_id,
|
||||
)
|
||||
|
||||
|
||||
class OpenRouterAITaskEntity(
|
||||
ai_task.AITaskEntity,
|
||||
OpenRouterEntity,
|
||||
):
|
||||
"""OpenRouter AI Task entity."""
|
||||
|
||||
_attr_name = None
|
||||
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
|
||||
async def _async_generate_data(
|
||||
self,
|
||||
task: ai_task.GenDataTask,
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> ai_task.GenDataTaskResult:
|
||||
"""Handle a generate data task."""
|
||||
await self._async_handle_chat_log(chat_log, task.name, task.structure)
|
||||
|
||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||
raise HomeAssistantError(
|
||||
"Last content in chat log is not an AssistantContent"
|
||||
)
|
||||
|
||||
text = chat_log.content[-1].content or ""
|
||||
|
||||
if not task.structure:
|
||||
return ai_task.GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data=text,
|
||||
)
|
||||
try:
|
||||
data = json_loads(text)
|
||||
except JSONDecodeError as err:
|
||||
raise HomeAssistantError(
|
||||
"Error with OpenRouter structured response"
|
||||
) from err
|
||||
|
||||
return ai_task.GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data=data,
|
||||
)
|
@ -5,7 +5,12 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from python_open_router import Model, OpenRouterClient, OpenRouterError
|
||||
from python_open_router import (
|
||||
Model,
|
||||
OpenRouterClient,
|
||||
OpenRouterError,
|
||||
SupportedParameter,
|
||||
)
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import (
|
||||
@ -43,7 +48,10 @@ class OpenRouterConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
cls, config_entry: ConfigEntry
|
||||
) -> dict[str, type[ConfigSubentryFlow]]:
|
||||
"""Return subentries supported by this handler."""
|
||||
return {"conversation": ConversationFlowHandler}
|
||||
return {
|
||||
"conversation": ConversationFlowHandler,
|
||||
"ai_task_data": AITaskDataFlowHandler,
|
||||
}
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
@ -78,13 +86,26 @@ class OpenRouterConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
)
|
||||
|
||||
|
||||
class ConversationFlowHandler(ConfigSubentryFlow):
|
||||
"""Handle subentry flow."""
|
||||
class OpenRouterSubentryFlowHandler(ConfigSubentryFlow):
|
||||
"""Handle subentry flow for OpenRouter."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the subentry flow."""
|
||||
self.models: dict[str, Model] = {}
|
||||
|
||||
async def _get_models(self) -> None:
|
||||
"""Fetch models from OpenRouter."""
|
||||
entry = self._get_entry()
|
||||
client = OpenRouterClient(
|
||||
entry.data[CONF_API_KEY], async_get_clientsession(self.hass)
|
||||
)
|
||||
models = await client.get_models()
|
||||
self.models = {model.id: model for model in models}
|
||||
|
||||
|
||||
class ConversationFlowHandler(OpenRouterSubentryFlowHandler):
|
||||
"""Handle subentry flow."""
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> SubentryFlowResult:
|
||||
@ -95,14 +116,16 @@ class ConversationFlowHandler(ConfigSubentryFlow):
|
||||
return self.async_create_entry(
|
||||
title=self.models[user_input[CONF_MODEL]].name, data=user_input
|
||||
)
|
||||
entry = self._get_entry()
|
||||
client = OpenRouterClient(
|
||||
entry.data[CONF_API_KEY], async_get_clientsession(self.hass)
|
||||
)
|
||||
models = await client.get_models()
|
||||
self.models = {model.id: model for model in models}
|
||||
try:
|
||||
await self._get_models()
|
||||
except OpenRouterError:
|
||||
return self.async_abort(reason="cannot_connect")
|
||||
except Exception:
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
return self.async_abort(reason="unknown")
|
||||
options = [
|
||||
SelectOptionDict(value=model.id, label=model.name) for model in models
|
||||
SelectOptionDict(value=model.id, label=model.name)
|
||||
for model in self.models.values()
|
||||
]
|
||||
|
||||
hass_apis: list[SelectOptionDict] = [
|
||||
@ -138,3 +161,40 @@ class ConversationFlowHandler(ConfigSubentryFlow):
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class AITaskDataFlowHandler(OpenRouterSubentryFlowHandler):
|
||||
"""Handle subentry flow."""
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> SubentryFlowResult:
|
||||
"""User flow to create a sensor subentry."""
|
||||
if user_input is not None:
|
||||
return self.async_create_entry(
|
||||
title=self.models[user_input[CONF_MODEL]].name, data=user_input
|
||||
)
|
||||
try:
|
||||
await self._get_models()
|
||||
except OpenRouterError:
|
||||
return self.async_abort(reason="cannot_connect")
|
||||
except Exception:
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
return self.async_abort(reason="unknown")
|
||||
options = [
|
||||
SelectOptionDict(value=model.id, label=model.name)
|
||||
for model in self.models.values()
|
||||
if SupportedParameter.STRUCTURED_OUTPUTS in model.supported_parameters
|
||||
]
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_MODEL): SelectSelector(
|
||||
SelectSelectorConfig(
|
||||
options=options, mode=SelectSelectorMode.DROPDOWN, sort=True
|
||||
),
|
||||
),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
@ -20,6 +20,8 @@ async def async_setup_entry(
|
||||
) -> None:
|
||||
"""Set up conversation entities."""
|
||||
for subentry_id, subentry in config_entry.subentries.items():
|
||||
if subentry.subentry_type != "conversation":
|
||||
continue
|
||||
async_add_entities(
|
||||
[OpenRouterConversationEntity(config_entry, subentry)],
|
||||
config_subentry_id=subentry_id,
|
||||
|
@ -4,10 +4,9 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
import json
|
||||
from typing import Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import openai
|
||||
from openai import NOT_GIVEN
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionMessage,
|
||||
@ -19,7 +18,9 @@ from openai.types.chat import (
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
||||
from openai.types.shared_params import FunctionDefinition
|
||||
from openai.types.shared_params import FunctionDefinition, ResponseFormatJSONSchema
|
||||
from openai.types.shared_params.response_format_json_schema import JSONSchema
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import conversation
|
||||
@ -36,6 +37,50 @@ from .const import DOMAIN, LOGGER
|
||||
MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
|
||||
def _adjust_schema(schema: dict[str, Any]) -> None:
|
||||
"""Adjust the schema to be compatible with OpenRouter API."""
|
||||
if schema["type"] == "object":
|
||||
if "properties" not in schema:
|
||||
return
|
||||
|
||||
if "required" not in schema:
|
||||
schema["required"] = []
|
||||
|
||||
# Ensure all properties are required
|
||||
for prop, prop_info in schema["properties"].items():
|
||||
_adjust_schema(prop_info)
|
||||
if prop not in schema["required"]:
|
||||
prop_info["type"] = [prop_info["type"], "null"]
|
||||
schema["required"].append(prop)
|
||||
|
||||
elif schema["type"] == "array":
|
||||
if "items" not in schema:
|
||||
return
|
||||
|
||||
_adjust_schema(schema["items"])
|
||||
|
||||
|
||||
def _format_structured_output(
|
||||
name: str, schema: vol.Schema, llm_api: llm.APIInstance | None
|
||||
) -> JSONSchema:
|
||||
"""Format the schema to be compatible with OpenRouter API."""
|
||||
result: JSONSchema = {
|
||||
"name": name,
|
||||
"strict": True,
|
||||
}
|
||||
result_schema = convert(
|
||||
schema,
|
||||
custom_serializer=(
|
||||
llm_api.custom_serializer if llm_api else llm.selector_serializer
|
||||
),
|
||||
)
|
||||
|
||||
_adjust_schema(result_schema)
|
||||
|
||||
result["schema"] = result_schema
|
||||
return result
|
||||
|
||||
|
||||
def _format_tool(
|
||||
tool: llm.Tool,
|
||||
custom_serializer: Callable[[Any], Any] | None,
|
||||
@ -136,9 +181,24 @@ class OpenRouterEntity(Entity):
|
||||
entry_type=dr.DeviceEntryType.SERVICE,
|
||||
)
|
||||
|
||||
async def _async_handle_chat_log(self, chat_log: conversation.ChatLog) -> None:
|
||||
async def _async_handle_chat_log(
|
||||
self,
|
||||
chat_log: conversation.ChatLog,
|
||||
structure_name: str | None = None,
|
||||
structure: vol.Schema | None = None,
|
||||
) -> None:
|
||||
"""Generate an answer for the chat log."""
|
||||
|
||||
model_args = {
|
||||
"model": self.model,
|
||||
"user": chat_log.conversation_id,
|
||||
"extra_headers": {
|
||||
"X-Title": "Home Assistant",
|
||||
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
||||
},
|
||||
"extra_body": {"require_parameters": True},
|
||||
}
|
||||
|
||||
tools: list[ChatCompletionToolParam] | None = None
|
||||
if chat_log.llm_api:
|
||||
tools = [
|
||||
@ -146,33 +206,37 @@ class OpenRouterEntity(Entity):
|
||||
for tool in chat_log.llm_api.tools
|
||||
]
|
||||
|
||||
messages = [
|
||||
if tools:
|
||||
model_args["tools"] = tools
|
||||
|
||||
model_args["messages"] = [
|
||||
m
|
||||
for content in chat_log.content
|
||||
if (m := _convert_content_to_chat_message(content))
|
||||
]
|
||||
|
||||
if structure:
|
||||
if TYPE_CHECKING:
|
||||
assert structure_name is not None
|
||||
model_args["response_format"] = ResponseFormatJSONSchema(
|
||||
type="json_schema",
|
||||
json_schema=_format_structured_output(
|
||||
structure_name, structure, chat_log.llm_api
|
||||
),
|
||||
)
|
||||
|
||||
client = self.entry.runtime_data
|
||||
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
try:
|
||||
result = await client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=tools or NOT_GIVEN,
|
||||
user=chat_log.conversation_id,
|
||||
extra_headers={
|
||||
"X-Title": "Home Assistant",
|
||||
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
||||
},
|
||||
)
|
||||
result = await client.chat.completions.create(**model_args)
|
||||
except openai.OpenAIError as err:
|
||||
LOGGER.error("Error talking to API: %s", err)
|
||||
raise HomeAssistantError("Error talking to API") from err
|
||||
|
||||
result_message = result.choices[0].message
|
||||
|
||||
messages.extend(
|
||||
model_args["messages"].extend(
|
||||
[
|
||||
msg
|
||||
async for content in chat_log.async_add_delta_content_stream(
|
||||
|
@ -37,7 +37,28 @@
|
||||
"initiate_flow": {
|
||||
"user": "Add conversation agent"
|
||||
},
|
||||
"entry_type": "Conversation agent"
|
||||
"entry_type": "Conversation agent",
|
||||
"abort": {
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||
"unknown": "[%key:common::config_flow::error::unknown%]"
|
||||
}
|
||||
},
|
||||
"ai_task_data": {
|
||||
"step": {
|
||||
"user": {
|
||||
"data": {
|
||||
"model": "[%key:component::open_router::config_subentries::conversation::step::user::data::model%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"initiate_flow": {
|
||||
"user": "Add Generate data with AI service"
|
||||
},
|
||||
"entry_type": "Generate data with AI service",
|
||||
"abort": {
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||
"unknown": "[%key:common::config_flow::error::unknown%]"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -49,9 +49,19 @@ def conversation_subentry_data(enable_assist: bool) -> dict[str, Any]:
|
||||
return res
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ai_task_data_subentry_data() -> dict[str, Any]:
|
||||
"""Mock AI task subentry data."""
|
||||
return {
|
||||
CONF_MODEL: "google/gemini-1.5-pro",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry(
|
||||
hass: HomeAssistant, conversation_subentry_data: dict[str, Any]
|
||||
hass: HomeAssistant,
|
||||
conversation_subentry_data: dict[str, Any],
|
||||
ai_task_data_subentry_data: dict[str, Any],
|
||||
) -> MockConfigEntry:
|
||||
"""Mock a config entry."""
|
||||
return MockConfigEntry(
|
||||
@ -67,7 +77,14 @@ def mock_config_entry(
|
||||
subentry_type="conversation",
|
||||
title="GPT-3.5 Turbo",
|
||||
unique_id=None,
|
||||
)
|
||||
),
|
||||
ConfigSubentryData(
|
||||
data=ai_task_data_subentry_data,
|
||||
subentry_id="ABCDEG",
|
||||
subentry_type="ai_task_data",
|
||||
title="Gemini 1.5 Pro",
|
||||
unique_id=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -85,6 +85,7 @@
|
||||
"logit_bias",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"structured_outputs",
|
||||
"response_format"
|
||||
]
|
||||
}
|
||||
|
53
tests/components/open_router/snapshots/test_ai_task.ambr
Normal file
53
tests/components/open_router/snapshots/test_ai_task.ambr
Normal file
@ -0,0 +1,53 @@
|
||||
# serializer version: 1
|
||||
# name: test_all_entities[ai_task.gemini_1_5_pro-entry]
|
||||
EntityRegistryEntrySnapshot({
|
||||
'aliases': set({
|
||||
}),
|
||||
'area_id': None,
|
||||
'capabilities': None,
|
||||
'config_entry_id': <ANY>,
|
||||
'config_subentry_id': <ANY>,
|
||||
'device_class': None,
|
||||
'device_id': <ANY>,
|
||||
'disabled_by': None,
|
||||
'domain': 'ai_task',
|
||||
'entity_category': None,
|
||||
'entity_id': 'ai_task.gemini_1_5_pro',
|
||||
'has_entity_name': True,
|
||||
'hidden_by': None,
|
||||
'icon': None,
|
||||
'id': <ANY>,
|
||||
'labels': set({
|
||||
}),
|
||||
'name': None,
|
||||
'options': dict({
|
||||
'conversation': dict({
|
||||
'should_expose': False,
|
||||
}),
|
||||
}),
|
||||
'original_device_class': None,
|
||||
'original_icon': None,
|
||||
'original_name': None,
|
||||
'platform': 'open_router',
|
||||
'previous_unique_id': None,
|
||||
'suggested_object_id': None,
|
||||
'supported_features': <AITaskEntityFeature: 1>,
|
||||
'translation_key': None,
|
||||
'unique_id': 'ABCDEG',
|
||||
'unit_of_measurement': None,
|
||||
})
|
||||
# ---
|
||||
# name: test_all_entities[ai_task.gemini_1_5_pro-state]
|
||||
StateSnapshot({
|
||||
'attributes': ReadOnlyDict({
|
||||
'friendly_name': 'Gemini 1.5 Pro',
|
||||
'supported_features': <AITaskEntityFeature: 1>,
|
||||
}),
|
||||
'context': <ANY>,
|
||||
'entity_id': 'ai_task.gemini_1_5_pro',
|
||||
'last_changed': <ANY>,
|
||||
'last_reported': <ANY>,
|
||||
'last_updated': <ANY>,
|
||||
'state': 'unknown',
|
||||
})
|
||||
# ---
|
210
tests/components/open_router/test_ai_task.py
Normal file
210
tests/components/open_router/test_ai_task.py
Normal file
@ -0,0 +1,210 @@
|
||||
"""Test AI Task structured data generation."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import ai_task
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import entity_registry as er, selector
|
||||
|
||||
from . import setup_integration
|
||||
|
||||
from tests.common import MockConfigEntry, snapshot_platform
|
||||
|
||||
|
||||
async def test_all_entities(
|
||||
hass: HomeAssistant,
|
||||
snapshot: SnapshotAssertion,
|
||||
mock_openai_client: AsyncMock,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test all entities."""
|
||||
with patch(
|
||||
"homeassistant.components.open_router.PLATFORMS",
|
||||
[Platform.AI_TASK],
|
||||
):
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id)
|
||||
|
||||
|
||||
async def test_generate_data(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_openai_client: AsyncMock,
|
||||
) -> None:
|
||||
"""Test AI Task data generation."""
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
entity_id = "ai_task.gemini_1_5_pro"
|
||||
|
||||
mock_openai_client.chat.completions.create = AsyncMock(
|
||||
return_value=ChatCompletion(
|
||||
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content="The test data",
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
tool_calls=None,
|
||||
),
|
||||
)
|
||||
],
|
||||
created=1700000000,
|
||||
model="x-ai/grok-3",
|
||||
object="chat.completion",
|
||||
system_fingerprint=None,
|
||||
usage=CompletionUsage(
|
||||
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
result = await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=entity_id,
|
||||
instructions="Generate test data",
|
||||
)
|
||||
|
||||
assert result.data == "The test data"
|
||||
|
||||
|
||||
async def test_generate_structured_data(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_openai_client: AsyncMock,
|
||||
) -> None:
|
||||
"""Test AI Task structured data generation."""
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
mock_openai_client.chat.completions.create = AsyncMock(
|
||||
return_value=ChatCompletion(
|
||||
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content='{"characters": ["Mario", "Luigi"]}',
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
tool_calls=None,
|
||||
),
|
||||
)
|
||||
],
|
||||
created=1700000000,
|
||||
model="x-ai/grok-3",
|
||||
object="chat.completion",
|
||||
system_fingerprint=None,
|
||||
usage=CompletionUsage(
|
||||
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
result = await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id="ai_task.gemini_1_5_pro",
|
||||
instructions="Generate test data",
|
||||
structure=vol.Schema(
|
||||
{
|
||||
vol.Required("characters"): selector.selector(
|
||||
{
|
||||
"text": {
|
||||
"multiple": True,
|
||||
}
|
||||
}
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
assert result.data == {"characters": ["Mario", "Luigi"]}
|
||||
assert mock_openai_client.chat.completions.create.call_args_list[0][1][
|
||||
"response_format"
|
||||
] == {
|
||||
"json_schema": {
|
||||
"name": "Test Task",
|
||||
"schema": {
|
||||
"properties": {
|
||||
"characters": {
|
||||
"items": {"type": "string"},
|
||||
"type": "array",
|
||||
}
|
||||
},
|
||||
"required": ["characters"],
|
||||
"type": "object",
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
"type": "json_schema",
|
||||
}
|
||||
|
||||
|
||||
async def test_generate_invalid_structured_data(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_openai_client: AsyncMock,
|
||||
) -> None:
|
||||
"""Test AI Task with invalid JSON response."""
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
mock_openai_client.chat.completions.create = AsyncMock(
|
||||
return_value=ChatCompletion(
|
||||
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content="INVALID JSON RESPONSE",
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
tool_calls=None,
|
||||
),
|
||||
)
|
||||
],
|
||||
created=1700000000,
|
||||
model="x-ai/grok-3",
|
||||
object="chat.completion",
|
||||
system_fingerprint=None,
|
||||
usage=CompletionUsage(
|
||||
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
HomeAssistantError, match="Error with OpenRouter structured response"
|
||||
):
|
||||
await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id="ai_task.gemini_1_5_pro",
|
||||
instructions="Generate test data",
|
||||
structure=vol.Schema(
|
||||
{
|
||||
vol.Required("characters"): selector.selector(
|
||||
{
|
||||
"text": {
|
||||
"multiple": True,
|
||||
}
|
||||
}
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
@ -110,9 +110,6 @@ async def test_create_conversation_agent(
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test creating a conversation agent."""
|
||||
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
@ -152,9 +149,6 @@ async def test_create_conversation_agent_no_control(
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test creating a conversation agent without control over the LLM API."""
|
||||
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
@ -184,3 +178,63 @@ async def test_create_conversation_agent_no_control(
|
||||
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||
CONF_PROMPT: "you are an assistant",
|
||||
}
|
||||
|
||||
|
||||
async def test_create_ai_task(
|
||||
hass: HomeAssistant,
|
||||
mock_open_router_client: AsyncMock,
|
||||
mock_openai_client: AsyncMock,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test creating an AI Task."""
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "ai_task_data"),
|
||||
context={"source": SOURCE_USER},
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert not result["errors"]
|
||||
assert result["step_id"] == "user"
|
||||
|
||||
assert result["data_schema"].schema["model"].config["options"] == [
|
||||
{"value": "openai/gpt-4", "label": "OpenAI: GPT-4"},
|
||||
]
|
||||
|
||||
result = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{CONF_MODEL: "openai/gpt-4"},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["data"] == {CONF_MODEL: "openai/gpt-4"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"subentry_type",
|
||||
["conversation", "ai_task_data"],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("exception", "reason"),
|
||||
[(OpenRouterError("exception"), "cannot_connect"), (Exception, "unknown")],
|
||||
)
|
||||
async def test_subentry_exceptions(
|
||||
hass: HomeAssistant,
|
||||
mock_open_router_client: AsyncMock,
|
||||
mock_openai_client: AsyncMock,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
subentry_type: str,
|
||||
exception: Exception,
|
||||
reason: str,
|
||||
) -> None:
|
||||
"""Test subentry flow exceptions."""
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
mock_open_router_client.get_models.side_effect = exception
|
||||
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, subentry_type),
|
||||
context={"source": SOURCE_USER},
|
||||
)
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == reason
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Tests for the OpenRouter integration."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
from openai.types import CompletionUsage
|
||||
@ -15,6 +15,7 @@ import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er, intent
|
||||
|
||||
@ -40,7 +41,11 @@ async def test_all_entities(
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test all entities."""
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
with patch(
|
||||
"homeassistant.components.open_router.PLATFORMS",
|
||||
[Platform.CONVERSATION],
|
||||
):
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user