mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 07:37:34 +00:00
Use OpenRouterClient to get the models (#148903)
This commit is contained in:
parent
e5c7e04329
commit
c075134845
@ -5,8 +5,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from python_open_router import Model, OpenRouterClient, OpenRouterError
|
||||||
from python_open_router import OpenRouterClient, OpenRouterError
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.config_entries import (
|
from homeassistant.config_entries import (
|
||||||
@ -20,7 +19,6 @@ from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
|
|||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.helpers import llm
|
from homeassistant.helpers import llm
|
||||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
from homeassistant.helpers.httpx_client import get_async_client
|
|
||||||
from homeassistant.helpers.selector import (
|
from homeassistant.helpers.selector import (
|
||||||
SelectOptionDict,
|
SelectOptionDict,
|
||||||
SelectSelector,
|
SelectSelector,
|
||||||
@ -85,7 +83,7 @@ class ConversationFlowHandler(ConfigSubentryFlow):
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize the subentry flow."""
|
"""Initialize the subentry flow."""
|
||||||
self.options: dict[str, str] = {}
|
self.models: dict[str, Model] = {}
|
||||||
|
|
||||||
async def async_step_user(
|
async def async_step_user(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
@ -95,14 +93,18 @@ class ConversationFlowHandler(ConfigSubentryFlow):
|
|||||||
if not user_input.get(CONF_LLM_HASS_API):
|
if not user_input.get(CONF_LLM_HASS_API):
|
||||||
user_input.pop(CONF_LLM_HASS_API, None)
|
user_input.pop(CONF_LLM_HASS_API, None)
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=self.options[user_input[CONF_MODEL]], data=user_input
|
title=self.models[user_input[CONF_MODEL]].name, data=user_input
|
||||||
)
|
)
|
||||||
entry = self._get_entry()
|
entry = self._get_entry()
|
||||||
client = AsyncOpenAI(
|
client = OpenRouterClient(
|
||||||
base_url="https://openrouter.ai/api/v1",
|
entry.data[CONF_API_KEY], async_get_clientsession(self.hass)
|
||||||
api_key=entry.data[CONF_API_KEY],
|
|
||||||
http_client=get_async_client(self.hass),
|
|
||||||
)
|
)
|
||||||
|
models = await client.get_models()
|
||||||
|
self.models = {model.id: model for model in models}
|
||||||
|
options = [
|
||||||
|
SelectOptionDict(value=model.id, label=model.name) for model in models
|
||||||
|
]
|
||||||
|
|
||||||
hass_apis: list[SelectOptionDict] = [
|
hass_apis: list[SelectOptionDict] = [
|
||||||
SelectOptionDict(
|
SelectOptionDict(
|
||||||
label=api.name,
|
label=api.name,
|
||||||
@ -110,10 +112,6 @@ class ConversationFlowHandler(ConfigSubentryFlow):
|
|||||||
)
|
)
|
||||||
for api in llm.async_get_apis(self.hass)
|
for api in llm.async_get_apis(self.hass)
|
||||||
]
|
]
|
||||||
options = []
|
|
||||||
async for model in client.with_options(timeout=10.0).models.list():
|
|
||||||
options.append(SelectOptionDict(value=model.id, label=model.name)) # type: ignore[attr-defined]
|
|
||||||
self.options[model.id] = model.name # type: ignore[attr-defined]
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="user",
|
step_id="user",
|
||||||
data_schema=vol.Schema(
|
data_schema=vol.Schema(
|
||||||
|
@ -3,12 +3,13 @@
|
|||||||
from collections.abc import AsyncGenerator, Generator
|
from collections.abc import AsyncGenerator, Generator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
from openai.types import CompletionUsage
|
from openai.types import CompletionUsage
|
||||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||||
from openai.types.chat.chat_completion import Choice
|
from openai.types.chat.chat_completion import Choice
|
||||||
import pytest
|
import pytest
|
||||||
|
from python_open_router import ModelsDataWrapper
|
||||||
|
|
||||||
from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
|
from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
|
||||||
from homeassistant.config_entries import ConfigSubentryData
|
from homeassistant.config_entries import ConfigSubentryData
|
||||||
@ -17,7 +18,7 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.helpers import llm
|
from homeassistant.helpers import llm
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry, async_load_fixture
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -40,7 +41,7 @@ def enable_assist() -> bool:
|
|||||||
def conversation_subentry_data(enable_assist: bool) -> dict[str, Any]:
|
def conversation_subentry_data(enable_assist: bool) -> dict[str, Any]:
|
||||||
"""Mock conversation subentry data."""
|
"""Mock conversation subentry data."""
|
||||||
res: dict[str, Any] = {
|
res: dict[str, Any] = {
|
||||||
CONF_MODEL: "gpt-3.5-turbo",
|
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||||
CONF_PROMPT: "You are a helpful assistant.",
|
CONF_PROMPT: "You are a helpful assistant.",
|
||||||
}
|
}
|
||||||
if enable_assist:
|
if enable_assist:
|
||||||
@ -82,24 +83,8 @@ class Model:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mock_openai_client() -> AsyncGenerator[AsyncMock]:
|
async def mock_openai_client() -> AsyncGenerator[AsyncMock]:
|
||||||
"""Initialize integration."""
|
"""Initialize integration."""
|
||||||
with (
|
with patch("homeassistant.components.open_router.AsyncOpenAI") as mock_client:
|
||||||
patch("homeassistant.components.open_router.AsyncOpenAI") as mock_client,
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.open_router.config_flow.AsyncOpenAI",
|
|
||||||
new=mock_client,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
client = mock_client.return_value
|
client = mock_client.return_value
|
||||||
client.with_options = MagicMock()
|
|
||||||
client.with_options.return_value.models = MagicMock()
|
|
||||||
client.with_options.return_value.models.list.return_value = (
|
|
||||||
get_generator_from_data(
|
|
||||||
[
|
|
||||||
Model(id="gpt-4", name="GPT-4"),
|
|
||||||
Model(id="gpt-3.5-turbo", name="GPT-3.5 Turbo"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
client.chat.completions.create = AsyncMock(
|
client.chat.completions.create = AsyncMock(
|
||||||
return_value=ChatCompletion(
|
return_value=ChatCompletion(
|
||||||
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
||||||
@ -128,13 +113,15 @@ async def mock_openai_client() -> AsyncGenerator[AsyncMock]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mock_open_router_client() -> AsyncGenerator[AsyncMock]:
|
async def mock_open_router_client(hass: HomeAssistant) -> AsyncGenerator[AsyncMock]:
|
||||||
"""Initialize integration."""
|
"""Initialize integration."""
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.open_router.config_flow.OpenRouterClient",
|
"homeassistant.components.open_router.config_flow.OpenRouterClient",
|
||||||
autospec=True,
|
autospec=True,
|
||||||
) as mock_client:
|
) as mock_client:
|
||||||
client = mock_client.return_value
|
client = mock_client.return_value
|
||||||
|
models = await async_load_fixture(hass, "models.json", DOMAIN)
|
||||||
|
client.get_models.return_value = ModelsDataWrapper.from_json(models).data
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
92
tests/components/open_router/fixtures/models.json
Normal file
92
tests/components/open_router/fixtures/models.json
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
{
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": "openai/gpt-3.5-turbo",
|
||||||
|
"canonical_slug": "openai/gpt-3.5-turbo",
|
||||||
|
"hugging_face_id": null,
|
||||||
|
"name": "OpenAI: GPT-3.5 Turbo",
|
||||||
|
"created": 1695859200,
|
||||||
|
"description": "This model is a variant of GPT-3.5 Turbo tuned for instructional prompts and omitting chat-related optimizations. Training data: up to Sep 2021.",
|
||||||
|
"context_length": 4095,
|
||||||
|
"architecture": {
|
||||||
|
"modality": "text->text",
|
||||||
|
"input_modalities": ["text"],
|
||||||
|
"output_modalities": ["text"],
|
||||||
|
"tokenizer": "GPT",
|
||||||
|
"instruct_type": "chatml"
|
||||||
|
},
|
||||||
|
"pricing": {
|
||||||
|
"prompt": "0.0000015",
|
||||||
|
"completion": "0.000002",
|
||||||
|
"request": "0",
|
||||||
|
"image": "0",
|
||||||
|
"web_search": "0",
|
||||||
|
"internal_reasoning": "0"
|
||||||
|
},
|
||||||
|
"top_provider": {
|
||||||
|
"context_length": 4095,
|
||||||
|
"max_completion_tokens": 4096,
|
||||||
|
"is_moderated": true
|
||||||
|
},
|
||||||
|
"per_request_limits": null,
|
||||||
|
"supported_parameters": [
|
||||||
|
"max_tokens",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"stop",
|
||||||
|
"frequency_penalty",
|
||||||
|
"presence_penalty",
|
||||||
|
"seed",
|
||||||
|
"logit_bias",
|
||||||
|
"logprobs",
|
||||||
|
"top_logprobs",
|
||||||
|
"response_format"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "openai/gpt-4",
|
||||||
|
"canonical_slug": "openai/gpt-4",
|
||||||
|
"hugging_face_id": null,
|
||||||
|
"name": "OpenAI: GPT-4",
|
||||||
|
"created": 1685232000,
|
||||||
|
"description": "OpenAI's flagship model, GPT-4 is a large-scale multimodal language model capable of solving difficult problems with greater accuracy than previous models due to its broader general knowledge and advanced reasoning capabilities. Training data: up to Sep 2021.",
|
||||||
|
"context_length": 8191,
|
||||||
|
"architecture": {
|
||||||
|
"modality": "text->text",
|
||||||
|
"input_modalities": ["text"],
|
||||||
|
"output_modalities": ["text"],
|
||||||
|
"tokenizer": "GPT",
|
||||||
|
"instruct_type": null
|
||||||
|
},
|
||||||
|
"pricing": {
|
||||||
|
"prompt": "0.00003",
|
||||||
|
"completion": "0.00006",
|
||||||
|
"request": "0",
|
||||||
|
"image": "0",
|
||||||
|
"web_search": "0",
|
||||||
|
"internal_reasoning": "0"
|
||||||
|
},
|
||||||
|
"top_provider": {
|
||||||
|
"context_length": 8191,
|
||||||
|
"max_completion_tokens": 4096,
|
||||||
|
"is_moderated": true
|
||||||
|
},
|
||||||
|
"per_request_limits": null,
|
||||||
|
"supported_parameters": [
|
||||||
|
"max_tokens",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"tools",
|
||||||
|
"tool_choice",
|
||||||
|
"stop",
|
||||||
|
"frequency_penalty",
|
||||||
|
"presence_penalty",
|
||||||
|
"seed",
|
||||||
|
"logit_bias",
|
||||||
|
"logprobs",
|
||||||
|
"top_logprobs",
|
||||||
|
"response_format"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
@ -124,13 +124,14 @@ async def test_create_conversation_agent(
|
|||||||
assert result["step_id"] == "user"
|
assert result["step_id"] == "user"
|
||||||
|
|
||||||
assert result["data_schema"].schema["model"].config["options"] == [
|
assert result["data_schema"].schema["model"].config["options"] == [
|
||||||
{"value": "gpt-3.5-turbo", "label": "GPT-3.5 Turbo"},
|
{"value": "openai/gpt-3.5-turbo", "label": "OpenAI: GPT-3.5 Turbo"},
|
||||||
|
{"value": "openai/gpt-4", "label": "OpenAI: GPT-4"},
|
||||||
]
|
]
|
||||||
|
|
||||||
result = await hass.config_entries.subentries.async_configure(
|
result = await hass.config_entries.subentries.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
{
|
{
|
||||||
CONF_MODEL: "gpt-3.5-turbo",
|
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||||
CONF_PROMPT: "you are an assistant",
|
CONF_PROMPT: "you are an assistant",
|
||||||
CONF_LLM_HASS_API: ["assist"],
|
CONF_LLM_HASS_API: ["assist"],
|
||||||
},
|
},
|
||||||
@ -138,7 +139,7 @@ async def test_create_conversation_agent(
|
|||||||
|
|
||||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||||
assert result["data"] == {
|
assert result["data"] == {
|
||||||
CONF_MODEL: "gpt-3.5-turbo",
|
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||||
CONF_PROMPT: "you are an assistant",
|
CONF_PROMPT: "you are an assistant",
|
||||||
CONF_LLM_HASS_API: ["assist"],
|
CONF_LLM_HASS_API: ["assist"],
|
||||||
}
|
}
|
||||||
@ -165,13 +166,14 @@ async def test_create_conversation_agent_no_control(
|
|||||||
assert result["step_id"] == "user"
|
assert result["step_id"] == "user"
|
||||||
|
|
||||||
assert result["data_schema"].schema["model"].config["options"] == [
|
assert result["data_schema"].schema["model"].config["options"] == [
|
||||||
{"value": "gpt-3.5-turbo", "label": "GPT-3.5 Turbo"},
|
{"value": "openai/gpt-3.5-turbo", "label": "OpenAI: GPT-3.5 Turbo"},
|
||||||
|
{"value": "openai/gpt-4", "label": "OpenAI: GPT-4"},
|
||||||
]
|
]
|
||||||
|
|
||||||
result = await hass.config_entries.subentries.async_configure(
|
result = await hass.config_entries.subentries.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
{
|
{
|
||||||
CONF_MODEL: "gpt-3.5-turbo",
|
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||||
CONF_PROMPT: "you are an assistant",
|
CONF_PROMPT: "you are an assistant",
|
||||||
CONF_LLM_HASS_API: [],
|
CONF_LLM_HASS_API: [],
|
||||||
},
|
},
|
||||||
@ -179,6 +181,6 @@ async def test_create_conversation_agent_no_control(
|
|||||||
|
|
||||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||||
assert result["data"] == {
|
assert result["data"] == {
|
||||||
CONF_MODEL: "gpt-3.5-turbo",
|
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||||
CONF_PROMPT: "you are an assistant",
|
CONF_PROMPT: "you are an assistant",
|
||||||
}
|
}
|
||||||
|
@ -65,7 +65,7 @@ async def test_default_prompt(
|
|||||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
assert mock_chat_log.content[1:] == snapshot
|
assert mock_chat_log.content[1:] == snapshot
|
||||||
call = mock_openai_client.chat.completions.create.call_args_list[0][1]
|
call = mock_openai_client.chat.completions.create.call_args_list[0][1]
|
||||||
assert call["model"] == "gpt-3.5-turbo"
|
assert call["model"] == "openai/gpt-3.5-turbo"
|
||||||
assert call["extra_headers"] == {
|
assert call["extra_headers"] == {
|
||||||
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
||||||
"X-Title": "Home Assistant",
|
"X-Title": "Home Assistant",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user