Add AI Task to OpenRouter (#149275)

This commit is contained in:
Joost Lekkerkerker 2025-07-30 16:01:44 +02:00 committed by GitHub
parent 223c34056d
commit 1b58809655
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 601 additions and 39 deletions

View File

@ -12,7 +12,7 @@ from homeassistant.helpers.httpx_client import get_async_client
from .const import LOGGER
PLATFORMS = [Platform.CONVERSATION]
PLATFORMS = [Platform.AI_TASK, Platform.CONVERSATION]
type OpenRouterConfigEntry = ConfigEntry[AsyncOpenAI]

View File

@ -0,0 +1,75 @@
"""AI Task integration for OpenRouter."""
from __future__ import annotations
from json import JSONDecodeError
import logging
from homeassistant.components import ai_task, conversation
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
from . import OpenRouterConfigEntry
from .entity import OpenRouterEntity
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: OpenRouterConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up AI Task entities."""
for subentry in config_entry.subentries.values():
if subentry.subentry_type != "ai_task_data":
continue
async_add_entities(
[OpenRouterAITaskEntity(config_entry, subentry)],
config_subentry_id=subentry.subentry_id,
)
class OpenRouterAITaskEntity(
ai_task.AITaskEntity,
OpenRouterEntity,
):
"""OpenRouter AI Task entity."""
_attr_name = None
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
async def _async_generate_data(
self,
task: ai_task.GenDataTask,
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Handle a generate data task."""
await self._async_handle_chat_log(chat_log, task.name, task.structure)
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
raise HomeAssistantError(
"Last content in chat log is not an AssistantContent"
)
text = chat_log.content[-1].content or ""
if not task.structure:
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=text,
)
try:
data = json_loads(text)
except JSONDecodeError as err:
raise HomeAssistantError(
"Error with OpenRouter structured response"
) from err
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=data,
)

View File

@ -5,7 +5,12 @@ from __future__ import annotations
import logging
from typing import Any
from python_open_router import Model, OpenRouterClient, OpenRouterError
from python_open_router import (
Model,
OpenRouterClient,
OpenRouterError,
SupportedParameter,
)
import voluptuous as vol
from homeassistant.config_entries import (
@ -43,7 +48,10 @@ class OpenRouterConfigFlow(ConfigFlow, domain=DOMAIN):
cls, config_entry: ConfigEntry
) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this handler."""
return {"conversation": ConversationFlowHandler}
return {
"conversation": ConversationFlowHandler,
"ai_task_data": AITaskDataFlowHandler,
}
async def async_step_user(
self, user_input: dict[str, Any] | None = None
@ -78,13 +86,26 @@ class OpenRouterConfigFlow(ConfigFlow, domain=DOMAIN):
)
class ConversationFlowHandler(ConfigSubentryFlow):
"""Handle subentry flow."""
class OpenRouterSubentryFlowHandler(ConfigSubentryFlow):
"""Handle subentry flow for OpenRouter."""
def __init__(self) -> None:
"""Initialize the subentry flow."""
self.models: dict[str, Model] = {}
async def _get_models(self) -> None:
"""Fetch models from OpenRouter."""
entry = self._get_entry()
client = OpenRouterClient(
entry.data[CONF_API_KEY], async_get_clientsession(self.hass)
)
models = await client.get_models()
self.models = {model.id: model for model in models}
class ConversationFlowHandler(OpenRouterSubentryFlowHandler):
"""Handle subentry flow."""
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
@ -95,14 +116,16 @@ class ConversationFlowHandler(ConfigSubentryFlow):
return self.async_create_entry(
title=self.models[user_input[CONF_MODEL]].name, data=user_input
)
entry = self._get_entry()
client = OpenRouterClient(
entry.data[CONF_API_KEY], async_get_clientsession(self.hass)
)
models = await client.get_models()
self.models = {model.id: model for model in models}
try:
await self._get_models()
except OpenRouterError:
return self.async_abort(reason="cannot_connect")
except Exception:
_LOGGER.exception("Unexpected exception")
return self.async_abort(reason="unknown")
options = [
SelectOptionDict(value=model.id, label=model.name) for model in models
SelectOptionDict(value=model.id, label=model.name)
for model in self.models.values()
]
hass_apis: list[SelectOptionDict] = [
@ -138,3 +161,40 @@ class ConversationFlowHandler(ConfigSubentryFlow):
}
),
)
class AITaskDataFlowHandler(OpenRouterSubentryFlowHandler):
"""Handle subentry flow."""
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""User flow to create a sensor subentry."""
if user_input is not None:
return self.async_create_entry(
title=self.models[user_input[CONF_MODEL]].name, data=user_input
)
try:
await self._get_models()
except OpenRouterError:
return self.async_abort(reason="cannot_connect")
except Exception:
_LOGGER.exception("Unexpected exception")
return self.async_abort(reason="unknown")
options = [
SelectOptionDict(value=model.id, label=model.name)
for model in self.models.values()
if SupportedParameter.STRUCTURED_OUTPUTS in model.supported_parameters
]
return self.async_show_form(
step_id="user",
data_schema=vol.Schema(
{
vol.Required(CONF_MODEL): SelectSelector(
SelectSelectorConfig(
options=options, mode=SelectSelectorMode.DROPDOWN, sort=True
),
),
}
),
)

View File

@ -20,6 +20,8 @@ async def async_setup_entry(
) -> None:
"""Set up conversation entities."""
for subentry_id, subentry in config_entry.subentries.items():
if subentry.subentry_type != "conversation":
continue
async_add_entities(
[OpenRouterConversationEntity(config_entry, subentry)],
config_subentry_id=subentry_id,

View File

@ -4,10 +4,9 @@ from __future__ import annotations
from collections.abc import AsyncGenerator, Callable
import json
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal
import openai
from openai import NOT_GIVEN
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
@ -19,7 +18,9 @@ from openai.types.chat import (
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import Function
from openai.types.shared_params import FunctionDefinition
from openai.types.shared_params import FunctionDefinition, ResponseFormatJSONSchema
from openai.types.shared_params.response_format_json_schema import JSONSchema
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import conversation
@ -36,6 +37,50 @@ from .const import DOMAIN, LOGGER
MAX_TOOL_ITERATIONS = 10
def _adjust_schema(schema: dict[str, Any]) -> None:
"""Adjust the schema to be compatible with OpenRouter API."""
if schema["type"] == "object":
if "properties" not in schema:
return
if "required" not in schema:
schema["required"] = []
# Ensure all properties are required
for prop, prop_info in schema["properties"].items():
_adjust_schema(prop_info)
if prop not in schema["required"]:
prop_info["type"] = [prop_info["type"], "null"]
schema["required"].append(prop)
elif schema["type"] == "array":
if "items" not in schema:
return
_adjust_schema(schema["items"])
def _format_structured_output(
name: str, schema: vol.Schema, llm_api: llm.APIInstance | None
) -> JSONSchema:
"""Format the schema to be compatible with OpenRouter API."""
result: JSONSchema = {
"name": name,
"strict": True,
}
result_schema = convert(
schema,
custom_serializer=(
llm_api.custom_serializer if llm_api else llm.selector_serializer
),
)
_adjust_schema(result_schema)
result["schema"] = result_schema
return result
def _format_tool(
tool: llm.Tool,
custom_serializer: Callable[[Any], Any] | None,
@ -136,9 +181,24 @@ class OpenRouterEntity(Entity):
entry_type=dr.DeviceEntryType.SERVICE,
)
async def _async_handle_chat_log(self, chat_log: conversation.ChatLog) -> None:
async def _async_handle_chat_log(
self,
chat_log: conversation.ChatLog,
structure_name: str | None = None,
structure: vol.Schema | None = None,
) -> None:
"""Generate an answer for the chat log."""
model_args = {
"model": self.model,
"user": chat_log.conversation_id,
"extra_headers": {
"X-Title": "Home Assistant",
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
},
"extra_body": {"require_parameters": True},
}
tools: list[ChatCompletionToolParam] | None = None
if chat_log.llm_api:
tools = [
@ -146,33 +206,37 @@ class OpenRouterEntity(Entity):
for tool in chat_log.llm_api.tools
]
messages = [
if tools:
model_args["tools"] = tools
model_args["messages"] = [
m
for content in chat_log.content
if (m := _convert_content_to_chat_message(content))
]
if structure:
if TYPE_CHECKING:
assert structure_name is not None
model_args["response_format"] = ResponseFormatJSONSchema(
type="json_schema",
json_schema=_format_structured_output(
structure_name, structure, chat_log.llm_api
),
)
client = self.entry.runtime_data
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
result = await client.chat.completions.create(
model=self.model,
messages=messages,
tools=tools or NOT_GIVEN,
user=chat_log.conversation_id,
extra_headers={
"X-Title": "Home Assistant",
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
},
)
result = await client.chat.completions.create(**model_args)
except openai.OpenAIError as err:
LOGGER.error("Error talking to API: %s", err)
raise HomeAssistantError("Error talking to API") from err
result_message = result.choices[0].message
messages.extend(
model_args["messages"].extend(
[
msg
async for content in chat_log.async_add_delta_content_stream(

View File

@ -37,7 +37,28 @@
"initiate_flow": {
"user": "Add conversation agent"
},
"entry_type": "Conversation agent"
"entry_type": "Conversation agent",
"abort": {
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
"unknown": "[%key:common::config_flow::error::unknown%]"
}
},
"ai_task_data": {
"step": {
"user": {
"data": {
"model": "[%key:component::open_router::config_subentries::conversation::step::user::data::model%]"
}
}
},
"initiate_flow": {
"user": "Add Generate data with AI service"
},
"entry_type": "Generate data with AI service",
"abort": {
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
"unknown": "[%key:common::config_flow::error::unknown%]"
}
}
}
}

View File

@ -49,9 +49,19 @@ def conversation_subentry_data(enable_assist: bool) -> dict[str, Any]:
return res
@pytest.fixture
def ai_task_data_subentry_data() -> dict[str, Any]:
"""Mock AI task subentry data."""
return {
CONF_MODEL: "google/gemini-1.5-pro",
}
@pytest.fixture
def mock_config_entry(
hass: HomeAssistant, conversation_subentry_data: dict[str, Any]
hass: HomeAssistant,
conversation_subentry_data: dict[str, Any],
ai_task_data_subentry_data: dict[str, Any],
) -> MockConfigEntry:
"""Mock a config entry."""
return MockConfigEntry(
@ -67,7 +77,14 @@ def mock_config_entry(
subentry_type="conversation",
title="GPT-3.5 Turbo",
unique_id=None,
)
),
ConfigSubentryData(
data=ai_task_data_subentry_data,
subentry_id="ABCDEG",
subentry_type="ai_task_data",
title="Gemini 1.5 Pro",
unique_id=None,
),
],
)

View File

@ -85,6 +85,7 @@
"logit_bias",
"logprobs",
"top_logprobs",
"structured_outputs",
"response_format"
]
}

View File

@ -0,0 +1,53 @@
# serializer version: 1
# name: test_all_entities[ai_task.gemini_1_5_pro-entry]
EntityRegistryEntrySnapshot({
'aliases': set({
}),
'area_id': None,
'capabilities': None,
'config_entry_id': <ANY>,
'config_subentry_id': <ANY>,
'device_class': None,
'device_id': <ANY>,
'disabled_by': None,
'domain': 'ai_task',
'entity_category': None,
'entity_id': 'ai_task.gemini_1_5_pro',
'has_entity_name': True,
'hidden_by': None,
'icon': None,
'id': <ANY>,
'labels': set({
}),
'name': None,
'options': dict({
'conversation': dict({
'should_expose': False,
}),
}),
'original_device_class': None,
'original_icon': None,
'original_name': None,
'platform': 'open_router',
'previous_unique_id': None,
'suggested_object_id': None,
'supported_features': <AITaskEntityFeature: 1>,
'translation_key': None,
'unique_id': 'ABCDEG',
'unit_of_measurement': None,
})
# ---
# name: test_all_entities[ai_task.gemini_1_5_pro-state]
StateSnapshot({
'attributes': ReadOnlyDict({
'friendly_name': 'Gemini 1.5 Pro',
'supported_features': <AITaskEntityFeature: 1>,
}),
'context': <ANY>,
'entity_id': 'ai_task.gemini_1_5_pro',
'last_changed': <ANY>,
'last_reported': <ANY>,
'last_updated': <ANY>,
'state': 'unknown',
})
# ---

View File

@ -0,0 +1,210 @@
"""Test AI Task structured data generation."""
from unittest.mock import AsyncMock, patch
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import ai_task
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er, selector
from . import setup_integration
from tests.common import MockConfigEntry, snapshot_platform
async def test_all_entities(
hass: HomeAssistant,
snapshot: SnapshotAssertion,
mock_openai_client: AsyncMock,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test all entities."""
with patch(
"homeassistant.components.open_router.PLATFORMS",
[Platform.AI_TASK],
):
await setup_integration(hass, mock_config_entry)
await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id)
async def test_generate_data(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_openai_client: AsyncMock,
) -> None:
"""Test AI Task data generation."""
await setup_integration(hass, mock_config_entry)
entity_id = "ai_task.gemini_1_5_pro"
mock_openai_client.chat.completions.create = AsyncMock(
return_value=ChatCompletion(
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content="The test data",
role="assistant",
function_call=None,
tool_calls=None,
),
)
],
created=1700000000,
model="x-ai/grok-3",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
)
)
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Generate test data",
)
assert result.data == "The test data"
async def test_generate_structured_data(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_openai_client: AsyncMock,
) -> None:
"""Test AI Task structured data generation."""
await setup_integration(hass, mock_config_entry)
mock_openai_client.chat.completions.create = AsyncMock(
return_value=ChatCompletion(
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content='{"characters": ["Mario", "Luigi"]}',
role="assistant",
function_call=None,
tool_calls=None,
),
)
],
created=1700000000,
model="x-ai/grok-3",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
)
)
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id="ai_task.gemini_1_5_pro",
instructions="Generate test data",
structure=vol.Schema(
{
vol.Required("characters"): selector.selector(
{
"text": {
"multiple": True,
}
}
)
},
),
)
assert result.data == {"characters": ["Mario", "Luigi"]}
assert mock_openai_client.chat.completions.create.call_args_list[0][1][
"response_format"
] == {
"json_schema": {
"name": "Test Task",
"schema": {
"properties": {
"characters": {
"items": {"type": "string"},
"type": "array",
}
},
"required": ["characters"],
"type": "object",
},
"strict": True,
},
"type": "json_schema",
}
async def test_generate_invalid_structured_data(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_openai_client: AsyncMock,
) -> None:
"""Test AI Task with invalid JSON response."""
await setup_integration(hass, mock_config_entry)
mock_openai_client.chat.completions.create = AsyncMock(
return_value=ChatCompletion(
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content="INVALID JSON RESPONSE",
role="assistant",
function_call=None,
tool_calls=None,
),
)
],
created=1700000000,
model="x-ai/grok-3",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
)
)
with pytest.raises(
HomeAssistantError, match="Error with OpenRouter structured response"
):
await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id="ai_task.gemini_1_5_pro",
instructions="Generate test data",
structure=vol.Schema(
{
vol.Required("characters"): selector.selector(
{
"text": {
"multiple": True,
}
}
)
},
),
)

View File

@ -110,9 +110,6 @@ async def test_create_conversation_agent(
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating a conversation agent."""
mock_config_entry.add_to_hass(hass)
await setup_integration(hass, mock_config_entry)
result = await hass.config_entries.subentries.async_init(
@ -152,9 +149,6 @@ async def test_create_conversation_agent_no_control(
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating a conversation agent without control over the LLM API."""
mock_config_entry.add_to_hass(hass)
await setup_integration(hass, mock_config_entry)
result = await hass.config_entries.subentries.async_init(
@ -184,3 +178,63 @@ async def test_create_conversation_agent_no_control(
CONF_MODEL: "openai/gpt-3.5-turbo",
CONF_PROMPT: "you are an assistant",
}
async def test_create_ai_task(
hass: HomeAssistant,
mock_open_router_client: AsyncMock,
mock_openai_client: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating an AI Task."""
await setup_integration(hass, mock_config_entry)
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "ai_task_data"),
context={"source": SOURCE_USER},
)
assert result["type"] is FlowResultType.FORM
assert not result["errors"]
assert result["step_id"] == "user"
assert result["data_schema"].schema["model"].config["options"] == [
{"value": "openai/gpt-4", "label": "OpenAI: GPT-4"},
]
result = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{CONF_MODEL: "openai/gpt-4"},
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["data"] == {CONF_MODEL: "openai/gpt-4"}
@pytest.mark.parametrize(
"subentry_type",
["conversation", "ai_task_data"],
)
@pytest.mark.parametrize(
("exception", "reason"),
[(OpenRouterError("exception"), "cannot_connect"), (Exception, "unknown")],
)
async def test_subentry_exceptions(
hass: HomeAssistant,
mock_open_router_client: AsyncMock,
mock_openai_client: AsyncMock,
mock_config_entry: MockConfigEntry,
subentry_type: str,
exception: Exception,
reason: str,
) -> None:
"""Test subentry flow exceptions."""
await setup_integration(hass, mock_config_entry)
mock_open_router_client.get_models.side_effect = exception
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, subentry_type),
context={"source": SOURCE_USER},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == reason

View File

@ -1,6 +1,6 @@
"""Tests for the OpenRouter integration."""
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, patch
from freezegun import freeze_time
from openai.types import CompletionUsage
@ -15,6 +15,7 @@ import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation
from homeassistant.const import Platform
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import entity_registry as er, intent
@ -40,7 +41,11 @@ async def test_all_entities(
entity_registry: er.EntityRegistry,
) -> None:
"""Test all entities."""
await setup_integration(hass, mock_config_entry)
with patch(
"homeassistant.components.open_router.PLATFORMS",
[Platform.CONVERSATION],
):
await setup_integration(hass, mock_config_entry)
await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id)