Mirror of https://github.com/home-assistant/core.git (synced 2025-07-28 07:37:34 +00:00)
Use OpenRouterClient to get the models (#148903)
commit c075134845 (parent e5c7e04329)
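In short: the conversation subentry flow now lists models through python_open_router instead of paging the OpenAI SDK's models endpoint, and it keeps the full Model objects around so entry titles can use the human-readable model name. A condensed sketch of the new lookup, assuming only the client API visible in this diff (fetch_model_options is a hypothetical helper, not code from the integration):

    from aiohttp import ClientSession

    from python_open_router import Model, OpenRouterClient


    async def fetch_model_options(api_key: str, session: ClientSession) -> dict[str, Model]:
        """Fetch the available OpenRouter models, keyed by model id."""
        client = OpenRouterClient(api_key, session)  # aiohttp-based client, per this diff
        models = await client.get_models()
        return {model.id: model for model in models}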
@@ -5,8 +5,7 @@ from __future__ import annotations

 import logging
 from typing import Any

-from openai import AsyncOpenAI
-from python_open_router import OpenRouterClient, OpenRouterError
+from python_open_router import Model, OpenRouterClient, OpenRouterError
 import voluptuous as vol

 from homeassistant.config_entries import (
@@ -20,7 +19,6 @@ from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
 from homeassistant.core import callback
 from homeassistant.helpers import llm
 from homeassistant.helpers.aiohttp_client import async_get_clientsession
-from homeassistant.helpers.httpx_client import get_async_client
 from homeassistant.helpers.selector import (
     SelectOptionDict,
     SelectSelector,
@@ -85,7 +83,7 @@ class ConversationFlowHandler(ConfigSubentryFlow):

     def __init__(self) -> None:
         """Initialize the subentry flow."""
-        self.options: dict[str, str] = {}
+        self.models: dict[str, Model] = {}

     async def async_step_user(
         self, user_input: dict[str, Any] | None = None
@@ -95,14 +93,18 @@ class ConversationFlowHandler(ConfigSubentryFlow):
             if not user_input.get(CONF_LLM_HASS_API):
                 user_input.pop(CONF_LLM_HASS_API, None)
             return self.async_create_entry(
-                title=self.options[user_input[CONF_MODEL]], data=user_input
+                title=self.models[user_input[CONF_MODEL]].name, data=user_input
             )
         entry = self._get_entry()
-        client = AsyncOpenAI(
-            base_url="https://openrouter.ai/api/v1",
-            api_key=entry.data[CONF_API_KEY],
-            http_client=get_async_client(self.hass),
+        client = OpenRouterClient(
+            entry.data[CONF_API_KEY], async_get_clientsession(self.hass)
         )
+        models = await client.get_models()
+        self.models = {model.id: model for model in models}
+        options = [
+            SelectOptionDict(value=model.id, label=model.name) for model in models
+        ]
+
         hass_apis: list[SelectOptionDict] = [
             SelectOptionDict(
                 label=api.name,
@@ -110,10 +112,6 @@ class ConversationFlowHandler(ConfigSubentryFlow):
             )
             for api in llm.async_get_apis(self.hass)
         ]
-        options = []
-        async for model in client.with_options(timeout=10.0).models.list():
-            options.append(SelectOptionDict(value=model.id, label=model.name))  # type: ignore[attr-defined]
-            self.options[model.id] = model.name  # type: ignore[attr-defined]
         return self.async_show_form(
             step_id="user",
             data_schema=vol.Schema(
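The options built above feed the model dropdown rendered by async_show_form. The schema body sits outside these hunks, so the following is only a plausible sketch of that wiring, inferred from the SelectOptionDict/SelectSelector imports and the tests' assertions on config["options"] (SelectSelectorConfig and SelectSelectorMode are assumptions, not shown in the diff):

    from homeassistant.helpers.selector import (
        SelectSelector,
        SelectSelectorConfig,
        SelectSelectorMode,
    )

    schema = vol.Schema(
        {
            # options is the list of SelectOptionDicts built from get_models()
            vol.Required(CONF_MODEL): SelectSelector(
                SelectSelectorConfig(options=options, mode=SelectSelectorMode.DROPDOWN)
            ),
        }
    )

The hunks that follow are from the test conftest.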
@@ -3,12 +3,13 @@
 from collections.abc import AsyncGenerator, Generator
 from dataclasses import dataclass
 from typing import Any
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, patch

 from openai.types import CompletionUsage
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
 import pytest
+from python_open_router import ModelsDataWrapper

 from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
 from homeassistant.config_entries import ConfigSubentryData
@@ -17,7 +18,7 @@ from homeassistant.core import HomeAssistant
 from homeassistant.helpers import llm
 from homeassistant.setup import async_setup_component

-from tests.common import MockConfigEntry
+from tests.common import MockConfigEntry, async_load_fixture


 @pytest.fixture
@@ -40,7 +41,7 @@ def enable_assist() -> bool:
 def conversation_subentry_data(enable_assist: bool) -> dict[str, Any]:
     """Mock conversation subentry data."""
     res: dict[str, Any] = {
-        CONF_MODEL: "gpt-3.5-turbo",
+        CONF_MODEL: "openai/gpt-3.5-turbo",
         CONF_PROMPT: "You are a helpful assistant.",
     }
     if enable_assist:
@@ -82,24 +83,8 @@ class Model:
 @pytest.fixture
 async def mock_openai_client() -> AsyncGenerator[AsyncMock]:
     """Initialize integration."""
-    with (
-        patch("homeassistant.components.open_router.AsyncOpenAI") as mock_client,
-        patch(
-            "homeassistant.components.open_router.config_flow.AsyncOpenAI",
-            new=mock_client,
-        ),
-    ):
+    with patch("homeassistant.components.open_router.AsyncOpenAI") as mock_client:
         client = mock_client.return_value
-        client.with_options = MagicMock()
-        client.with_options.return_value.models = MagicMock()
-        client.with_options.return_value.models.list.return_value = (
-            get_generator_from_data(
-                [
-                    Model(id="gpt-4", name="GPT-4"),
-                    Model(id="gpt-3.5-turbo", name="GPT-3.5 Turbo"),
-                ],
-            )
-        )
         client.chat.completions.create = AsyncMock(
             return_value=ChatCompletion(
                 id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
@@ -128,13 +113,15 @@ async def mock_openai_client() -> AsyncGenerator[AsyncMock]:


 @pytest.fixture
-async def mock_open_router_client() -> AsyncGenerator[AsyncMock]:
+async def mock_open_router_client(hass: HomeAssistant) -> AsyncGenerator[AsyncMock]:
     """Initialize integration."""
     with patch(
         "homeassistant.components.open_router.config_flow.OpenRouterClient",
         autospec=True,
     ) as mock_client:
         client = mock_client.return_value
+        models = await async_load_fixture(hass, "models.json", DOMAIN)
+        client.get_models.return_value = ModelsDataWrapper.from_json(models).data
         yield client

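With autospec=True, client.get_models is an AsyncMock, so any code path that awaits it receives the fixture-backed Model list. A hypothetical test using the fixture (illustrative only, not part of this commit):

    from unittest.mock import AsyncMock


    async def test_models_come_from_fixture(mock_open_router_client: AsyncMock) -> None:
        """The mocked client serves the two models defined in models.json."""
        models = await mock_open_router_client.get_models()
        assert [model.id for model in models] == ["openai/gpt-3.5-turbo", "openai/gpt-4"]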
tests/components/open_router/fixtures/models.json (new file, 92 lines)
@@ -0,0 +1,92 @@
+{
+  "data": [
+    {
+      "id": "openai/gpt-3.5-turbo",
+      "canonical_slug": "openai/gpt-3.5-turbo",
+      "hugging_face_id": null,
+      "name": "OpenAI: GPT-3.5 Turbo",
+      "created": 1695859200,
+      "description": "This model is a variant of GPT-3.5 Turbo tuned for instructional prompts and omitting chat-related optimizations. Training data: up to Sep 2021.",
+      "context_length": 4095,
+      "architecture": {
+        "modality": "text->text",
+        "input_modalities": ["text"],
+        "output_modalities": ["text"],
+        "tokenizer": "GPT",
+        "instruct_type": "chatml"
+      },
+      "pricing": {
+        "prompt": "0.0000015",
+        "completion": "0.000002",
+        "request": "0",
+        "image": "0",
+        "web_search": "0",
+        "internal_reasoning": "0"
+      },
+      "top_provider": {
+        "context_length": 4095,
+        "max_completion_tokens": 4096,
+        "is_moderated": true
+      },
+      "per_request_limits": null,
+      "supported_parameters": [
+        "max_tokens",
+        "temperature",
+        "top_p",
+        "stop",
+        "frequency_penalty",
+        "presence_penalty",
+        "seed",
+        "logit_bias",
+        "logprobs",
+        "top_logprobs",
+        "response_format"
+      ]
+    },
+    {
+      "id": "openai/gpt-4",
+      "canonical_slug": "openai/gpt-4",
+      "hugging_face_id": null,
+      "name": "OpenAI: GPT-4",
+      "created": 1685232000,
+      "description": "OpenAI's flagship model, GPT-4 is a large-scale multimodal language model capable of solving difficult problems with greater accuracy than previous models due to its broader general knowledge and advanced reasoning capabilities. Training data: up to Sep 2021.",
+      "context_length": 8191,
+      "architecture": {
+        "modality": "text->text",
+        "input_modalities": ["text"],
+        "output_modalities": ["text"],
+        "tokenizer": "GPT",
+        "instruct_type": null
+      },
+      "pricing": {
+        "prompt": "0.00003",
+        "completion": "0.00006",
+        "request": "0",
+        "image": "0",
+        "web_search": "0",
+        "internal_reasoning": "0"
+      },
+      "top_provider": {
+        "context_length": 8191,
+        "max_completion_tokens": 4096,
+        "is_moderated": true
+      },
+      "per_request_limits": null,
+      "supported_parameters": [
+        "max_tokens",
+        "temperature",
+        "top_p",
+        "tools",
+        "tool_choice",
+        "stop",
+        "frequency_penalty",
+        "presence_penalty",
+        "seed",
+        "logit_bias",
+        "logprobs",
+        "top_logprobs",
+        "response_format"
+      ]
+    }
+  ]
+}
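For reference, this fixture round-trips through the same parsing path the conftest uses; assuming ModelsDataWrapper.from_json accepts raw JSON text and exposes the parsed models on .data (both shown in the conftest hunk above), loading it by hand looks like:

    from pathlib import Path

    from python_open_router import ModelsDataWrapper

    raw = Path("tests/components/open_router/fixtures/models.json").read_text()
    models = ModelsDataWrapper.from_json(raw).data
    assert [model.name for model in models] == ["OpenAI: GPT-3.5 Turbo", "OpenAI: GPT-4"]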
@@ -124,13 +124,14 @@ async def test_create_conversation_agent(
     assert result["step_id"] == "user"

     assert result["data_schema"].schema["model"].config["options"] == [
-        {"value": "gpt-3.5-turbo", "label": "GPT-3.5 Turbo"},
+        {"value": "openai/gpt-3.5-turbo", "label": "OpenAI: GPT-3.5 Turbo"},
+        {"value": "openai/gpt-4", "label": "OpenAI: GPT-4"},
     ]

     result = await hass.config_entries.subentries.async_configure(
         result["flow_id"],
         {
-            CONF_MODEL: "gpt-3.5-turbo",
+            CONF_MODEL: "openai/gpt-3.5-turbo",
             CONF_PROMPT: "you are an assistant",
             CONF_LLM_HASS_API: ["assist"],
         },
@@ -138,7 +139,7 @@

     assert result["type"] is FlowResultType.CREATE_ENTRY
     assert result["data"] == {
-        CONF_MODEL: "gpt-3.5-turbo",
+        CONF_MODEL: "openai/gpt-3.5-turbo",
         CONF_PROMPT: "you are an assistant",
         CONF_LLM_HASS_API: ["assist"],
     }
@@ -165,13 +166,14 @@ async def test_create_conversation_agent_no_control(
     assert result["step_id"] == "user"

     assert result["data_schema"].schema["model"].config["options"] == [
-        {"value": "gpt-3.5-turbo", "label": "GPT-3.5 Turbo"},
+        {"value": "openai/gpt-3.5-turbo", "label": "OpenAI: GPT-3.5 Turbo"},
+        {"value": "openai/gpt-4", "label": "OpenAI: GPT-4"},
     ]

     result = await hass.config_entries.subentries.async_configure(
         result["flow_id"],
         {
-            CONF_MODEL: "gpt-3.5-turbo",
+            CONF_MODEL: "openai/gpt-3.5-turbo",
             CONF_PROMPT: "you are an assistant",
             CONF_LLM_HASS_API: [],
         },
@@ -179,6 +181,6 @@

     assert result["type"] is FlowResultType.CREATE_ENTRY
     assert result["data"] == {
-        CONF_MODEL: "gpt-3.5-turbo",
+        CONF_MODEL: "openai/gpt-3.5-turbo",
         CONF_PROMPT: "you are an assistant",
     }
@@ -65,7 +65,7 @@ async def test_default_prompt(
     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
     assert mock_chat_log.content[1:] == snapshot
     call = mock_openai_client.chat.completions.create.call_args_list[0][1]
-    assert call["model"] == "gpt-3.5-turbo"
+    assert call["model"] == "openai/gpt-3.5-turbo"
     assert call["extra_headers"] == {
         "HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
         "X-Title": "Home Assistant",
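This assertion pins down the request shape: chat completions still go through the OpenAI SDK pointed at OpenRouter, now with the provider-namespaced model id and OpenRouter's attribution headers passed via extra_headers. A sketch of that call, assuming an AsyncOpenAI client configured as before this commit (the message content is illustrative):

    completion = await client.chat.completions.create(
        model="openai/gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        extra_headers={
            "HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
            "X-Title": "Home Assistant",
        },
    )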