From 1b58809655bebe1f50159efc5670f3a5459696c6 Mon Sep 17 00:00:00 2001 From: Joost Lekkerkerker Date: Wed, 30 Jul 2025 16:01:44 +0200 Subject: [PATCH] Add AI Task to OpenRouter (#149275) --- .../components/open_router/__init__.py | 2 +- .../components/open_router/ai_task.py | 75 +++++++ .../components/open_router/config_flow.py | 82 ++++++- .../components/open_router/conversation.py | 2 + .../components/open_router/entity.py | 96 ++++++-- .../components/open_router/strings.json | 23 +- tests/components/open_router/conftest.py | 21 +- .../open_router/fixtures/models.json | 1 + .../open_router/snapshots/test_ai_task.ambr | 53 +++++ tests/components/open_router/test_ai_task.py | 210 ++++++++++++++++++ .../open_router/test_config_flow.py | 66 +++++- .../open_router/test_conversation.py | 9 +- 12 files changed, 601 insertions(+), 39 deletions(-) create mode 100644 homeassistant/components/open_router/ai_task.py create mode 100644 tests/components/open_router/snapshots/test_ai_task.ambr create mode 100644 tests/components/open_router/test_ai_task.py diff --git a/homeassistant/components/open_router/__init__.py b/homeassistant/components/open_router/__init__.py index 477fabca54c..9850f72f71d 100644 --- a/homeassistant/components/open_router/__init__.py +++ b/homeassistant/components/open_router/__init__.py @@ -12,7 +12,7 @@ from homeassistant.helpers.httpx_client import get_async_client from .const import LOGGER -PLATFORMS = [Platform.CONVERSATION] +PLATFORMS = [Platform.AI_TASK, Platform.CONVERSATION] type OpenRouterConfigEntry = ConfigEntry[AsyncOpenAI] diff --git a/homeassistant/components/open_router/ai_task.py b/homeassistant/components/open_router/ai_task.py new file mode 100644 index 00000000000..fa5d8d0f68e --- /dev/null +++ b/homeassistant/components/open_router/ai_task.py @@ -0,0 +1,75 @@ +"""AI Task integration for OpenRouter.""" + +from __future__ import annotations + +from json import JSONDecodeError +import logging + +from homeassistant.components import ai_task, conversation +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback +from homeassistant.util.json import json_loads + +from . import OpenRouterConfigEntry +from .entity import OpenRouterEntity + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: OpenRouterConfigEntry, + async_add_entities: AddConfigEntryEntitiesCallback, +) -> None: + """Set up AI Task entities.""" + for subentry in config_entry.subentries.values(): + if subentry.subentry_type != "ai_task_data": + continue + + async_add_entities( + [OpenRouterAITaskEntity(config_entry, subentry)], + config_subentry_id=subentry.subentry_id, + ) + + +class OpenRouterAITaskEntity( + ai_task.AITaskEntity, + OpenRouterEntity, +): + """OpenRouter AI Task entity.""" + + _attr_name = None + _attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA + + async def _async_generate_data( + self, + task: ai_task.GenDataTask, + chat_log: conversation.ChatLog, + ) -> ai_task.GenDataTaskResult: + """Handle a generate data task.""" + await self._async_handle_chat_log(chat_log, task.name, task.structure) + + if not isinstance(chat_log.content[-1], conversation.AssistantContent): + raise HomeAssistantError( + "Last content in chat log is not an AssistantContent" + ) + + text = chat_log.content[-1].content or "" + + if not task.structure: + return ai_task.GenDataTaskResult( + conversation_id=chat_log.conversation_id, + data=text, + ) + try: + data = json_loads(text) + except JSONDecodeError as err: + raise HomeAssistantError( + "Error with OpenRouter structured response" + ) from err + + return ai_task.GenDataTaskResult( + conversation_id=chat_log.conversation_id, + data=data, + ) diff --git a/homeassistant/components/open_router/config_flow.py b/homeassistant/components/open_router/config_flow.py index 96f3769575b..2afe2129a4c 100644 --- a/homeassistant/components/open_router/config_flow.py +++ b/homeassistant/components/open_router/config_flow.py @@ -5,7 +5,12 @@ from __future__ import annotations import logging from typing import Any -from python_open_router import Model, OpenRouterClient, OpenRouterError +from python_open_router import ( + Model, + OpenRouterClient, + OpenRouterError, + SupportedParameter, +) import voluptuous as vol from homeassistant.config_entries import ( @@ -43,7 +48,10 @@ class OpenRouterConfigFlow(ConfigFlow, domain=DOMAIN): cls, config_entry: ConfigEntry ) -> dict[str, type[ConfigSubentryFlow]]: """Return subentries supported by this handler.""" - return {"conversation": ConversationFlowHandler} + return { + "conversation": ConversationFlowHandler, + "ai_task_data": AITaskDataFlowHandler, + } async def async_step_user( self, user_input: dict[str, Any] | None = None @@ -78,13 +86,26 @@ class OpenRouterConfigFlow(ConfigFlow, domain=DOMAIN): ) -class ConversationFlowHandler(ConfigSubentryFlow): - """Handle subentry flow.""" +class OpenRouterSubentryFlowHandler(ConfigSubentryFlow): + """Handle subentry flow for OpenRouter.""" def __init__(self) -> None: """Initialize the subentry flow.""" self.models: dict[str, Model] = {} + async def _get_models(self) -> None: + """Fetch models from OpenRouter.""" + entry = self._get_entry() + client = OpenRouterClient( + entry.data[CONF_API_KEY], async_get_clientsession(self.hass) + ) + models = await client.get_models() + self.models = {model.id: model for model in models} + + +class ConversationFlowHandler(OpenRouterSubentryFlowHandler): + """Handle subentry flow.""" + async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: @@ -95,14 +116,16 @@ class ConversationFlowHandler(ConfigSubentryFlow): return self.async_create_entry( title=self.models[user_input[CONF_MODEL]].name, data=user_input ) - entry = self._get_entry() - client = OpenRouterClient( - entry.data[CONF_API_KEY], async_get_clientsession(self.hass) - ) - models = await client.get_models() - self.models = {model.id: model for model in models} + try: + await self._get_models() + except OpenRouterError: + return self.async_abort(reason="cannot_connect") + except Exception: + _LOGGER.exception("Unexpected exception") + return self.async_abort(reason="unknown") options = [ - SelectOptionDict(value=model.id, label=model.name) for model in models + SelectOptionDict(value=model.id, label=model.name) + for model in self.models.values() ] hass_apis: list[SelectOptionDict] = [ @@ -138,3 +161,40 @@ class ConversationFlowHandler(ConfigSubentryFlow): } ), ) + + +class AITaskDataFlowHandler(OpenRouterSubentryFlowHandler): + """Handle subentry flow.""" + + async def async_step_user( + self, user_input: dict[str, Any] | None = None + ) -> SubentryFlowResult: + """User flow to create a sensor subentry.""" + if user_input is not None: + return self.async_create_entry( + title=self.models[user_input[CONF_MODEL]].name, data=user_input + ) + try: + await self._get_models() + except OpenRouterError: + return self.async_abort(reason="cannot_connect") + except Exception: + _LOGGER.exception("Unexpected exception") + return self.async_abort(reason="unknown") + options = [ + SelectOptionDict(value=model.id, label=model.name) + for model in self.models.values() + if SupportedParameter.STRUCTURED_OUTPUTS in model.supported_parameters + ] + return self.async_show_form( + step_id="user", + data_schema=vol.Schema( + { + vol.Required(CONF_MODEL): SelectSelector( + SelectSelectorConfig( + options=options, mode=SelectSelectorMode.DROPDOWN, sort=True + ), + ), + } + ), + ) diff --git a/homeassistant/components/open_router/conversation.py b/homeassistant/components/open_router/conversation.py index 826931d3da7..3c185ecd77c 100644 --- a/homeassistant/components/open_router/conversation.py +++ b/homeassistant/components/open_router/conversation.py @@ -20,6 +20,8 @@ async def async_setup_entry( ) -> None: """Set up conversation entities.""" for subentry_id, subentry in config_entry.subentries.items(): + if subentry.subentry_type != "conversation": + continue async_add_entities( [OpenRouterConversationEntity(config_entry, subentry)], config_subentry_id=subentry_id, diff --git a/homeassistant/components/open_router/entity.py b/homeassistant/components/open_router/entity.py index e706656d377..ac01ec89704 100644 --- a/homeassistant/components/open_router/entity.py +++ b/homeassistant/components/open_router/entity.py @@ -4,10 +4,9 @@ from __future__ import annotations from collections.abc import AsyncGenerator, Callable import json -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal import openai -from openai import NOT_GIVEN from openai.types.chat import ( ChatCompletionAssistantMessageParam, ChatCompletionMessage, @@ -19,7 +18,9 @@ from openai.types.chat import ( ChatCompletionUserMessageParam, ) from openai.types.chat.chat_completion_message_tool_call_param import Function -from openai.types.shared_params import FunctionDefinition +from openai.types.shared_params import FunctionDefinition, ResponseFormatJSONSchema +from openai.types.shared_params.response_format_json_schema import JSONSchema +import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import conversation @@ -36,6 +37,50 @@ from .const import DOMAIN, LOGGER MAX_TOOL_ITERATIONS = 10 +def _adjust_schema(schema: dict[str, Any]) -> None: + """Adjust the schema to be compatible with OpenRouter API.""" + if schema["type"] == "object": + if "properties" not in schema: + return + + if "required" not in schema: + schema["required"] = [] + + # Ensure all properties are required + for prop, prop_info in schema["properties"].items(): + _adjust_schema(prop_info) + if prop not in schema["required"]: + prop_info["type"] = [prop_info["type"], "null"] + schema["required"].append(prop) + + elif schema["type"] == "array": + if "items" not in schema: + return + + _adjust_schema(schema["items"]) + + +def _format_structured_output( + name: str, schema: vol.Schema, llm_api: llm.APIInstance | None +) -> JSONSchema: + """Format the schema to be compatible with OpenRouter API.""" + result: JSONSchema = { + "name": name, + "strict": True, + } + result_schema = convert( + schema, + custom_serializer=( + llm_api.custom_serializer if llm_api else llm.selector_serializer + ), + ) + + _adjust_schema(result_schema) + + result["schema"] = result_schema + return result + + def _format_tool( tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None, @@ -136,9 +181,24 @@ class OpenRouterEntity(Entity): entry_type=dr.DeviceEntryType.SERVICE, ) - async def _async_handle_chat_log(self, chat_log: conversation.ChatLog) -> None: + async def _async_handle_chat_log( + self, + chat_log: conversation.ChatLog, + structure_name: str | None = None, + structure: vol.Schema | None = None, + ) -> None: """Generate an answer for the chat log.""" + model_args = { + "model": self.model, + "user": chat_log.conversation_id, + "extra_headers": { + "X-Title": "Home Assistant", + "HTTP-Referer": "https://www.home-assistant.io/integrations/open_router", + }, + "extra_body": {"require_parameters": True}, + } + tools: list[ChatCompletionToolParam] | None = None if chat_log.llm_api: tools = [ @@ -146,33 +206,37 @@ class OpenRouterEntity(Entity): for tool in chat_log.llm_api.tools ] - messages = [ + if tools: + model_args["tools"] = tools + + model_args["messages"] = [ m for content in chat_log.content if (m := _convert_content_to_chat_message(content)) ] + if structure: + if TYPE_CHECKING: + assert structure_name is not None + model_args["response_format"] = ResponseFormatJSONSchema( + type="json_schema", + json_schema=_format_structured_output( + structure_name, structure, chat_log.llm_api + ), + ) + client = self.entry.runtime_data for _iteration in range(MAX_TOOL_ITERATIONS): try: - result = await client.chat.completions.create( - model=self.model, - messages=messages, - tools=tools or NOT_GIVEN, - user=chat_log.conversation_id, - extra_headers={ - "X-Title": "Home Assistant", - "HTTP-Referer": "https://www.home-assistant.io/integrations/open_router", - }, - ) + result = await client.chat.completions.create(**model_args) except openai.OpenAIError as err: LOGGER.error("Error talking to API: %s", err) raise HomeAssistantError("Error talking to API") from err result_message = result.choices[0].message - messages.extend( + model_args["messages"].extend( [ msg async for content in chat_log.async_add_delta_content_stream( diff --git a/homeassistant/components/open_router/strings.json b/homeassistant/components/open_router/strings.json index 91c4cc350ae..e73a65cd178 100644 --- a/homeassistant/components/open_router/strings.json +++ b/homeassistant/components/open_router/strings.json @@ -37,7 +37,28 @@ "initiate_flow": { "user": "Add conversation agent" }, - "entry_type": "Conversation agent" + "entry_type": "Conversation agent", + "abort": { + "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", + "unknown": "[%key:common::config_flow::error::unknown%]" + } + }, + "ai_task_data": { + "step": { + "user": { + "data": { + "model": "[%key:component::open_router::config_subentries::conversation::step::user::data::model%]" + } + } + }, + "initiate_flow": { + "user": "Add Generate data with AI service" + }, + "entry_type": "Generate data with AI service", + "abort": { + "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", + "unknown": "[%key:common::config_flow::error::unknown%]" + } } } } diff --git a/tests/components/open_router/conftest.py b/tests/components/open_router/conftest.py index 7bb967f369f..33ca4d790c9 100644 --- a/tests/components/open_router/conftest.py +++ b/tests/components/open_router/conftest.py @@ -49,9 +49,19 @@ def conversation_subentry_data(enable_assist: bool) -> dict[str, Any]: return res +@pytest.fixture +def ai_task_data_subentry_data() -> dict[str, Any]: + """Mock AI task subentry data.""" + return { + CONF_MODEL: "google/gemini-1.5-pro", + } + + @pytest.fixture def mock_config_entry( - hass: HomeAssistant, conversation_subentry_data: dict[str, Any] + hass: HomeAssistant, + conversation_subentry_data: dict[str, Any], + ai_task_data_subentry_data: dict[str, Any], ) -> MockConfigEntry: """Mock a config entry.""" return MockConfigEntry( @@ -67,7 +77,14 @@ def mock_config_entry( subentry_type="conversation", title="GPT-3.5 Turbo", unique_id=None, - ) + ), + ConfigSubentryData( + data=ai_task_data_subentry_data, + subentry_id="ABCDEG", + subentry_type="ai_task_data", + title="Gemini 1.5 Pro", + unique_id=None, + ), ], ) diff --git a/tests/components/open_router/fixtures/models.json b/tests/components/open_router/fixtures/models.json index 0a35686094e..b17f584c0e6 100644 --- a/tests/components/open_router/fixtures/models.json +++ b/tests/components/open_router/fixtures/models.json @@ -85,6 +85,7 @@ "logit_bias", "logprobs", "top_logprobs", + "structured_outputs", "response_format" ] } diff --git a/tests/components/open_router/snapshots/test_ai_task.ambr b/tests/components/open_router/snapshots/test_ai_task.ambr new file mode 100644 index 00000000000..0839f6fef9b --- /dev/null +++ b/tests/components/open_router/snapshots/test_ai_task.ambr @@ -0,0 +1,53 @@ +# serializer version: 1 +# name: test_all_entities[ai_task.gemini_1_5_pro-entry] + EntityRegistryEntrySnapshot({ + 'aliases': set({ + }), + 'area_id': None, + 'capabilities': None, + 'config_entry_id': , + 'config_subentry_id': , + 'device_class': None, + 'device_id': , + 'disabled_by': None, + 'domain': 'ai_task', + 'entity_category': None, + 'entity_id': 'ai_task.gemini_1_5_pro', + 'has_entity_name': True, + 'hidden_by': None, + 'icon': None, + 'id': , + 'labels': set({ + }), + 'name': None, + 'options': dict({ + 'conversation': dict({ + 'should_expose': False, + }), + }), + 'original_device_class': None, + 'original_icon': None, + 'original_name': None, + 'platform': 'open_router', + 'previous_unique_id': None, + 'suggested_object_id': None, + 'supported_features': , + 'translation_key': None, + 'unique_id': 'ABCDEG', + 'unit_of_measurement': None, + }) +# --- +# name: test_all_entities[ai_task.gemini_1_5_pro-state] + StateSnapshot({ + 'attributes': ReadOnlyDict({ + 'friendly_name': 'Gemini 1.5 Pro', + 'supported_features': , + }), + 'context': , + 'entity_id': 'ai_task.gemini_1_5_pro', + 'last_changed': , + 'last_reported': , + 'last_updated': , + 'state': 'unknown', + }) +# --- diff --git a/tests/components/open_router/test_ai_task.py b/tests/components/open_router/test_ai_task.py new file mode 100644 index 00000000000..0b6c2933be7 --- /dev/null +++ b/tests/components/open_router/test_ai_task.py @@ -0,0 +1,210 @@ +"""Test AI Task structured data generation.""" + +from unittest.mock import AsyncMock, patch + +from openai.types import CompletionUsage +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +import pytest +from syrupy.assertion import SnapshotAssertion +import voluptuous as vol + +from homeassistant.components import ai_task +from homeassistant.const import Platform +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import entity_registry as er, selector + +from . import setup_integration + +from tests.common import MockConfigEntry, snapshot_platform + + +async def test_all_entities( + hass: HomeAssistant, + snapshot: SnapshotAssertion, + mock_openai_client: AsyncMock, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Test all entities.""" + with patch( + "homeassistant.components.open_router.PLATFORMS", + [Platform.AI_TASK], + ): + await setup_integration(hass, mock_config_entry) + + await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id) + + +async def test_generate_data( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_openai_client: AsyncMock, +) -> None: + """Test AI Task data generation.""" + await setup_integration(hass, mock_config_entry) + + entity_id = "ai_task.gemini_1_5_pro" + + mock_openai_client.chat.completions.create = AsyncMock( + return_value=ChatCompletion( + id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + content="The test data", + role="assistant", + function_call=None, + tool_calls=None, + ), + ) + ], + created=1700000000, + model="x-ai/grok-3", + object="chat.completion", + system_fingerprint=None, + usage=CompletionUsage( + completion_tokens=9, prompt_tokens=8, total_tokens=17 + ), + ) + ) + + result = await ai_task.async_generate_data( + hass, + task_name="Test Task", + entity_id=entity_id, + instructions="Generate test data", + ) + + assert result.data == "The test data" + + +async def test_generate_structured_data( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_openai_client: AsyncMock, +) -> None: + """Test AI Task structured data generation.""" + await setup_integration(hass, mock_config_entry) + + mock_openai_client.chat.completions.create = AsyncMock( + return_value=ChatCompletion( + id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + content='{"characters": ["Mario", "Luigi"]}', + role="assistant", + function_call=None, + tool_calls=None, + ), + ) + ], + created=1700000000, + model="x-ai/grok-3", + object="chat.completion", + system_fingerprint=None, + usage=CompletionUsage( + completion_tokens=9, prompt_tokens=8, total_tokens=17 + ), + ) + ) + + result = await ai_task.async_generate_data( + hass, + task_name="Test Task", + entity_id="ai_task.gemini_1_5_pro", + instructions="Generate test data", + structure=vol.Schema( + { + vol.Required("characters"): selector.selector( + { + "text": { + "multiple": True, + } + } + ) + }, + ), + ) + + assert result.data == {"characters": ["Mario", "Luigi"]} + assert mock_openai_client.chat.completions.create.call_args_list[0][1][ + "response_format" + ] == { + "json_schema": { + "name": "Test Task", + "schema": { + "properties": { + "characters": { + "items": {"type": "string"}, + "type": "array", + } + }, + "required": ["characters"], + "type": "object", + }, + "strict": True, + }, + "type": "json_schema", + } + + +async def test_generate_invalid_structured_data( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_openai_client: AsyncMock, +) -> None: + """Test AI Task with invalid JSON response.""" + await setup_integration(hass, mock_config_entry) + + mock_openai_client.chat.completions.create = AsyncMock( + return_value=ChatCompletion( + id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + content="INVALID JSON RESPONSE", + role="assistant", + function_call=None, + tool_calls=None, + ), + ) + ], + created=1700000000, + model="x-ai/grok-3", + object="chat.completion", + system_fingerprint=None, + usage=CompletionUsage( + completion_tokens=9, prompt_tokens=8, total_tokens=17 + ), + ) + ) + + with pytest.raises( + HomeAssistantError, match="Error with OpenRouter structured response" + ): + await ai_task.async_generate_data( + hass, + task_name="Test Task", + entity_id="ai_task.gemini_1_5_pro", + instructions="Generate test data", + structure=vol.Schema( + { + vol.Required("characters"): selector.selector( + { + "text": { + "multiple": True, + } + } + ) + }, + ), + ) diff --git a/tests/components/open_router/test_config_flow.py b/tests/components/open_router/test_config_flow.py index 0720f6d90f5..b406e75507b 100644 --- a/tests/components/open_router/test_config_flow.py +++ b/tests/components/open_router/test_config_flow.py @@ -110,9 +110,6 @@ async def test_create_conversation_agent( mock_config_entry: MockConfigEntry, ) -> None: """Test creating a conversation agent.""" - - mock_config_entry.add_to_hass(hass) - await setup_integration(hass, mock_config_entry) result = await hass.config_entries.subentries.async_init( @@ -152,9 +149,6 @@ async def test_create_conversation_agent_no_control( mock_config_entry: MockConfigEntry, ) -> None: """Test creating a conversation agent without control over the LLM API.""" - - mock_config_entry.add_to_hass(hass) - await setup_integration(hass, mock_config_entry) result = await hass.config_entries.subentries.async_init( @@ -184,3 +178,63 @@ async def test_create_conversation_agent_no_control( CONF_MODEL: "openai/gpt-3.5-turbo", CONF_PROMPT: "you are an assistant", } + + +async def test_create_ai_task( + hass: HomeAssistant, + mock_open_router_client: AsyncMock, + mock_openai_client: AsyncMock, + mock_config_entry: MockConfigEntry, +) -> None: + """Test creating an AI Task.""" + await setup_integration(hass, mock_config_entry) + + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, "ai_task_data"), + context={"source": SOURCE_USER}, + ) + assert result["type"] is FlowResultType.FORM + assert not result["errors"] + assert result["step_id"] == "user" + + assert result["data_schema"].schema["model"].config["options"] == [ + {"value": "openai/gpt-4", "label": "OpenAI: GPT-4"}, + ] + + result = await hass.config_entries.subentries.async_configure( + result["flow_id"], + {CONF_MODEL: "openai/gpt-4"}, + ) + + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["data"] == {CONF_MODEL: "openai/gpt-4"} + + +@pytest.mark.parametrize( + "subentry_type", + ["conversation", "ai_task_data"], +) +@pytest.mark.parametrize( + ("exception", "reason"), + [(OpenRouterError("exception"), "cannot_connect"), (Exception, "unknown")], +) +async def test_subentry_exceptions( + hass: HomeAssistant, + mock_open_router_client: AsyncMock, + mock_openai_client: AsyncMock, + mock_config_entry: MockConfigEntry, + subentry_type: str, + exception: Exception, + reason: str, +) -> None: + """Test subentry flow exceptions.""" + await setup_integration(hass, mock_config_entry) + + mock_open_router_client.get_models.side_effect = exception + + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, subentry_type), + context={"source": SOURCE_USER}, + ) + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == reason diff --git a/tests/components/open_router/test_conversation.py b/tests/components/open_router/test_conversation.py index 93f8264801a..afbdd907f93 100644 --- a/tests/components/open_router/test_conversation.py +++ b/tests/components/open_router/test_conversation.py @@ -1,6 +1,6 @@ """Tests for the OpenRouter integration.""" -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, patch from freezegun import freeze_time from openai.types import CompletionUsage @@ -15,6 +15,7 @@ import pytest from syrupy.assertion import SnapshotAssertion from homeassistant.components import conversation +from homeassistant.const import Platform from homeassistant.core import Context, HomeAssistant from homeassistant.helpers import entity_registry as er, intent @@ -40,7 +41,11 @@ async def test_all_entities( entity_registry: er.EntityRegistry, ) -> None: """Test all entities.""" - await setup_integration(hass, mock_config_entry) + with patch( + "homeassistant.components.open_router.PLATFORMS", + [Platform.CONVERSATION], + ): + await setup_integration(hass, mock_config_entry) await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id)