Add support for calling tools in Open Router (#148881)

Joost Lekkerkerker 2025-07-18 05:49:27 +02:00 committed by GitHub
parent 073ea813f0
commit 50688bbd69
8 changed files with 497 additions and 51 deletions

homeassistant/components/open_router/config_flow.py

@@ -16,8 +16,9 @@ from homeassistant.config_entries import (
ConfigSubentryFlow,
SubentryFlowResult,
)
-from homeassistant.const import CONF_API_KEY, CONF_MODEL
+from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
from homeassistant.core import callback
from homeassistant.helpers import llm
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.httpx_client import get_async_client
from homeassistant.helpers.selector import (
@@ -25,9 +26,10 @@ from homeassistant.helpers.selector import (
SelectSelector,
SelectSelectorConfig,
SelectSelectorMode,
TemplateSelector,
)
-from .const import DOMAIN
+from .const import CONF_PROMPT, DOMAIN, RECOMMENDED_CONVERSATION_OPTIONS
_LOGGER = logging.getLogger(__name__)
@@ -90,6 +92,8 @@ class ConversationFlowHandler(ConfigSubentryFlow):
) -> SubentryFlowResult:
"""User flow to create a sensor subentry."""
if user_input is not None:
if not user_input.get(CONF_LLM_HASS_API):
user_input.pop(CONF_LLM_HASS_API, None)
return self.async_create_entry(
title=self.options[user_input[CONF_MODEL]], data=user_input
)
@@ -99,11 +103,17 @@ class ConversationFlowHandler(ConfigSubentryFlow):
api_key=entry.data[CONF_API_KEY],
http_client=get_async_client(self.hass),
)
hass_apis: list[SelectOptionDict] = [
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(self.hass)
]
options = []
async for model in client.with_options(timeout=10.0).models.list():
options.append(SelectOptionDict(value=model.id, label=model.name)) # type: ignore[attr-defined]
self.options[model.id] = model.name # type: ignore[attr-defined]
return self.async_show_form(
step_id="user",
data_schema=vol.Schema(
@@ -113,6 +123,20 @@ class ConversationFlowHandler(ConfigSubentryFlow):
options=options, mode=SelectSelectorMode.DROPDOWN, sort=True
),
),
vol.Optional(
CONF_PROMPT,
description={
"suggested_value": RECOMMENDED_CONVERSATION_OPTIONS[
CONF_PROMPT
]
},
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
default=RECOMMENDED_CONVERSATION_OPTIONS[CONF_LLM_HASS_API],
): SelectSelector(
SelectSelectorConfig(options=hass_apis, multiple=True)
),
}
),
)
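
The guard at the top of the user step (popping CONF_LLM_HASS_API when the multi-select comes back empty) means a subentry created without Assist control stores no LLM API key at all. A minimal standalone sketch of that behavior; the user_input dict here is illustrative, not taken from the diff:

    user_input = {
        "model": "gpt-3.5-turbo",
        "prompt": "you are an assistant",
        "llm_hass_api": [],  # multi-select left empty in the form
    }
    if not user_input.get("llm_hass_api"):
        # An empty list is falsy, so the key is dropped entirely before
        # the subentry is created.
        user_input.pop("llm_hass_api", None)
    assert user_input == {
        "model": "gpt-3.5-turbo",
        "prompt": "you are an assistant",
    }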

homeassistant/components/open_router/const.py

@@ -2,5 +2,17 @@
import logging
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.helpers import llm
DOMAIN = "open_router"
LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"
CONF_RECOMMENDED = "recommended"
RECOMMENDED_CONVERSATION_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
}
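
RECOMMENDED_CONVERSATION_OPTIONS doubles as the source of form defaults in config_flow.py above. A hedged sketch of how a consumer could fall back to these defaults for unset keys; resolve_options is an illustrative helper, not part of the integration:

    from homeassistant.const import CONF_LLM_HASS_API
    from homeassistant.helpers import llm

    def resolve_options(subentry_data: dict) -> dict:
        """Overlay stored subentry data on the recommended defaults."""
        options = dict(RECOMMENDED_CONVERSATION_OPTIONS)
        options.update(subentry_data)
        return options

    # With nothing stored, the recommended Assist API and the default
    # instructions prompt apply.
    assert resolve_options({})[CONF_LLM_HASS_API] == [llm.LLM_API_ASSIST]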

homeassistant/components/open_router/conversation.py

@@ -1,25 +1,39 @@
"""Conversation support for OpenRouter."""
-from typing import Literal
+from collections.abc import AsyncGenerator, Callable
+import json
+from typing import Any, Literal
import openai
from openai import NOT_GIVEN
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import Function
from openai.types.shared_params import FunctionDefinition
from voluptuous_openapi import convert
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import OpenRouterConfigEntry
-from .const import DOMAIN, LOGGER
+from .const import CONF_PROMPT, DOMAIN, LOGGER
# Max number of back-and-forth turns with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
async def async_setup_entry(
@@ -35,13 +49,31 @@ async def async_setup_entry(
)
def _format_tool(
tool: llm.Tool,
custom_serializer: Callable[[Any], Any] | None,
) -> ChatCompletionToolParam:
"""Format tool specification."""
tool_spec = FunctionDefinition(
name=tool.name,
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
)
if tool.description:
tool_spec["description"] = tool.description
return ChatCompletionToolParam(type="function", function=tool_spec)
def _convert_content_to_chat_message(
content: conversation.Content,
) -> ChatCompletionMessageParam | None:
"""Convert any native chat message for this agent to the native format."""
LOGGER.debug("_convert_content_to_chat_message=%s", content)
if isinstance(content, conversation.ToolResultContent):
-        return None
+        return ChatCompletionToolMessageParam(
+            role="tool",
+            tool_call_id=content.tool_call_id,
+            content=json.dumps(content.tool_result),
+        )
role: Literal["user", "assistant", "system"] = content.role
if role == "system" and content.content:
@@ -51,13 +83,55 @@ def _convert_content_to_chat_message(
return ChatCompletionUserMessageParam(role="user", content=content.content)
if role == "assistant":
-        return ChatCompletionAssistantMessageParam(
-            role="assistant", content=content.content
+        param = ChatCompletionAssistantMessageParam(
+            role="assistant",
+            content=content.content,
)
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
param["tool_calls"] = [
ChatCompletionMessageToolCallParam(
type="function",
id=tool_call.id,
function=Function(
arguments=json.dumps(tool_call.tool_args),
name=tool_call.tool_name,
),
)
for tool_call in content.tool_calls
]
return param
LOGGER.warning("Could not convert message to Completions API: %s", content)
return None
def _decode_tool_arguments(arguments: str) -> Any:
"""Decode tool call arguments."""
try:
return json.loads(arguments)
except json.JSONDecodeError as err:
raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err
async def _transform_response(
message: ChatCompletionMessage,
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
"""Transform the OpenRouter message to a ChatLog format."""
data: conversation.AssistantContentDeltaDict = {
"role": message.role,
"content": message.content,
}
if message.tool_calls:
data["tool_calls"] = [
llm.ToolInput(
id=tool_call.id,
tool_name=tool_call.function.name,
tool_args=_decode_tool_arguments(tool_call.function.arguments),
)
for tool_call in message.tool_calls
]
yield data
class OpenRouterConversationEntity(conversation.ConversationEntity):
"""OpenRouter conversation agent."""
@@ -75,6 +149,10 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
name=subentry.title,
entry_type=DeviceEntryType.SERVICE,
)
if self.subentry.data.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
)
@property
def supported_languages(self) -> list[str] | Literal["*"]:
@@ -93,12 +171,19 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
await chat_log.async_provide_llm_data(
user_input.as_llm_context(DOMAIN),
options.get(CONF_LLM_HASS_API),
-                None,
+                options.get(CONF_PROMPT),
user_input.extra_system_prompt,
)
except conversation.ConverseError as err:
return err.as_conversation_result()
tools: list[ChatCompletionToolParam] | None = None
if chat_log.llm_api:
tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
messages = [
m
for content in chat_log.content
@@ -107,10 +192,12 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
client = self.entry.runtime_data
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
result = await client.chat.completions.create(
model=self.model,
messages=messages,
tools=tools or NOT_GIVEN,
user=chat_log.conversation_id,
extra_headers={
"X-Title": "Home Assistant",
@@ -123,11 +210,16 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
result_message = result.choices[0].message
-            chat_log.async_add_assistant_content_without_tools(
-                conversation.AssistantContent(
-                    agent_id=user_input.agent_id,
-                    content=result_message.content,
-                )
-            )
+            messages.extend(
+                [
+                    msg
+                    async for content in chat_log.async_add_delta_content_stream(
+                        user_input.agent_id, _transform_response(result_message)
+                    )
+                    if (msg := _convert_content_to_chat_message(content))
+                ]
+            )
if not chat_log.unresponded_tool_results:
break
return conversation.async_get_result_from_chat_log(user_input, chat_log)
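
Condensed, the tool-calling flow added above works like this: Home Assistant llm.Tool specs are converted to OpenAI function definitions, each model reply is streamed into the chat log (which executes any tool calls), and the loop repeats until no tool results are left unanswered or MAX_TOOL_ITERATIONS is hit. A hedged sketch of that control flow reusing the module's helpers; the function name and parameters are illustrative, not part of the diff:

    async def _chat_with_tools(client, model, messages, chat_log, agent_id):
        """Illustrative condensation of the loop above."""
        tools = None
        if chat_log.llm_api:
            tools = [
                _format_tool(tool, chat_log.llm_api.custom_serializer)
                for tool in chat_log.llm_api.tools
            ]
        for _iteration in range(MAX_TOOL_ITERATIONS):  # hard cap on round trips
            result = await client.chat.completions.create(
                model=model, messages=messages, tools=tools or NOT_GIVEN
            )
            message = result.choices[0].message
            # Streaming the reply through the chat log executes any tool
            # calls; the resulting tool messages are appended so the model
            # sees them on the next round trip.
            messages.extend(
                [
                    msg
                    async for content in chat_log.async_add_delta_content_stream(
                        agent_id, _transform_response(message)
                    )
                    if (msg := _convert_content_to_chat_message(content))
                ]
            )
            if not chat_log.unresponded_tool_results:
                break  # the model produced a final answer, stop iterating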

homeassistant/components/open_router/strings.json

@@ -24,7 +24,13 @@
"user": {
"description": "Configure the new conversation agent",
"data": {
"model": "Model"
"model": "Model",
"prompt": "Instructions",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
},
"data_description": {
"model": "The model to use for the conversation agent",
"prompt": "Instruct how the LLM should respond. This can be a template."
}
}
},

tests/components/open_router/conftest.py

@@ -2,6 +2,7 @@
from collections.abc import AsyncGenerator, Generator
from dataclasses import dataclass
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from openai.types import CompletionUsage
@@ -9,10 +10,11 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
import pytest
-from homeassistant.components.open_router.const import DOMAIN
+from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
from homeassistant.config_entries import ConfigSubentryData
-from homeassistant.const import CONF_API_KEY, CONF_MODEL
+from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
@@ -29,7 +31,27 @@ def mock_setup_entry() -> Generator[AsyncMock]:
@pytest.fixture
-def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
+def enable_assist() -> bool:
+    """Return whether the conversation agent should enable the Assist API."""
+    return False
@pytest.fixture
def conversation_subentry_data(enable_assist: bool) -> dict[str, Any]:
"""Mock conversation subentry data."""
res: dict[str, Any] = {
CONF_MODEL: "gpt-3.5-turbo",
CONF_PROMPT: "You are a helpful assistant.",
}
if enable_assist:
res[CONF_LLM_HASS_API] = [llm.LLM_API_ASSIST]
return res
@pytest.fixture
def mock_config_entry(
hass: HomeAssistant, conversation_subentry_data: dict[str, Any]
) -> MockConfigEntry:
"""Mock a config entry."""
return MockConfigEntry(
title="OpenRouter",
@@ -39,7 +61,7 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
},
subentries_data=[
ConfigSubentryData(
-            data={CONF_MODEL: "gpt-3.5-turbo"},
+            data=conversation_subentry_data,
subentry_id="ABCDEF",
subentry_type="conversation",
title="GPT-3.5 Turbo",

tests/components/open_router/snapshots/test_conversation.ambr

@@ -1,4 +1,108 @@
# serializer version: 1
# name: test_all_entities[assist][conversation.gpt_3_5_turbo-entry]
EntityRegistryEntrySnapshot({
'aliases': set({
}),
'area_id': None,
'capabilities': None,
'config_entry_id': <ANY>,
'config_subentry_id': <ANY>,
'device_class': None,
'device_id': <ANY>,
'disabled_by': None,
'domain': 'conversation',
'entity_category': None,
'entity_id': 'conversation.gpt_3_5_turbo',
'has_entity_name': True,
'hidden_by': None,
'icon': None,
'id': <ANY>,
'labels': set({
}),
'name': None,
'options': dict({
'conversation': dict({
'should_expose': False,
}),
}),
'original_device_class': None,
'original_icon': None,
'original_name': None,
'platform': 'open_router',
'previous_unique_id': None,
'suggested_object_id': None,
'supported_features': <ConversationEntityFeature: 1>,
'translation_key': None,
'unique_id': 'ABCDEF',
'unit_of_measurement': None,
})
# ---
# name: test_all_entities[assist][conversation.gpt_3_5_turbo-state]
StateSnapshot({
'attributes': ReadOnlyDict({
'friendly_name': 'GPT-3.5 Turbo',
'supported_features': <ConversationEntityFeature: 1>,
}),
'context': <ANY>,
'entity_id': 'conversation.gpt_3_5_turbo',
'last_changed': <ANY>,
'last_reported': <ANY>,
'last_updated': <ANY>,
'state': 'unknown',
})
# ---
# name: test_all_entities[no_assist][conversation.gpt_3_5_turbo-entry]
EntityRegistryEntrySnapshot({
'aliases': set({
}),
'area_id': None,
'capabilities': None,
'config_entry_id': <ANY>,
'config_subentry_id': <ANY>,
'device_class': None,
'device_id': <ANY>,
'disabled_by': None,
'domain': 'conversation',
'entity_category': None,
'entity_id': 'conversation.gpt_3_5_turbo',
'has_entity_name': True,
'hidden_by': None,
'icon': None,
'id': <ANY>,
'labels': set({
}),
'name': None,
'options': dict({
'conversation': dict({
'should_expose': False,
}),
}),
'original_device_class': None,
'original_icon': None,
'original_name': None,
'platform': 'open_router',
'previous_unique_id': None,
'suggested_object_id': None,
'supported_features': 0,
'translation_key': None,
'unique_id': 'ABCDEF',
'unit_of_measurement': None,
})
# ---
# name: test_all_entities[no_assist][conversation.gpt_3_5_turbo-state]
StateSnapshot({
'attributes': ReadOnlyDict({
'friendly_name': 'GPT-3.5 Turbo',
'supported_features': <ConversationEntityFeature: 0>,
}),
'context': <ANY>,
'entity_id': 'conversation.gpt_3_5_turbo',
'last_changed': <ANY>,
'last_reported': <ANY>,
'last_updated': <ANY>,
'state': 'unknown',
})
# ---
# name: test_default_prompt
list([
dict({
@@ -14,3 +118,39 @@
}),
])
# ---
# name: test_function_call[True]
list([
dict({
'attachments': None,
'content': 'Please call the test function',
'role': 'user',
}),
dict({
'agent_id': 'conversation.gpt_3_5_turbo',
'content': None,
'role': 'assistant',
'tool_calls': list([
dict({
'id': 'call_call_1',
'tool_args': dict({
'param1': 'call1',
}),
'tool_name': 'test_tool',
}),
]),
}),
dict({
'agent_id': 'conversation.gpt_3_5_turbo',
'role': 'tool_result',
'tool_call_id': 'call_call_1',
'tool_name': 'test_tool',
'tool_result': 'value1',
}),
dict({
'agent_id': 'conversation.gpt_3_5_turbo',
'content': 'I have successfully called the function',
'role': 'assistant',
'tool_calls': None,
}),
])
# ---
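
The blocks above are syrupy snapshots: serialized values recorded in this .ambr file and compared on later runs. A minimal hedged sketch of the assertion style that produces them; the test name and data are illustrative:

    from syrupy.assertion import SnapshotAssertion

    def test_example(snapshot: SnapshotAssertion) -> None:
        value = {"role": "assistant", "content": "hi"}
        # Run pytest with --snapshot-update to record this value; later
        # runs fail if it no longer matches the stored snapshot.
        assert value == snapshot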

tests/components/open_router/test_config_flow.py

@@ -5,9 +5,9 @@ from unittest.mock import AsyncMock
import pytest
from python_open_router import OpenRouterError
-from homeassistant.components.open_router.const import DOMAIN
-from homeassistant.config_entries import SOURCE_USER, ConfigSubentry
-from homeassistant.const import CONF_API_KEY, CONF_MODEL
+from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
+from homeassistant.config_entries import SOURCE_USER
+from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
@@ -129,18 +129,56 @@ async def test_create_conversation_agent(
result = await hass.config_entries.subentries.async_configure(
result["flow_id"],
-        {CONF_MODEL: "gpt-3.5-turbo"},
+        {
+            CONF_MODEL: "gpt-3.5-turbo",
+            CONF_PROMPT: "you are an assistant",
+            CONF_LLM_HASS_API: ["assist"],
+        },
)
assert result["type"] is FlowResultType.CREATE_ENTRY
-    subentry_id = list(mock_config_entry.subentries)[0]
-    assert (
-        ConfigSubentry(
-            data={CONF_MODEL: "gpt-3.5-turbo"},
-            subentry_id=subentry_id,
-            subentry_type="conversation",
-            title="GPT-3.5 Turbo",
-            unique_id=None,
-        )
-        in mock_config_entry.subentries.values()
-    )
+    assert result["data"] == {
+        CONF_MODEL: "gpt-3.5-turbo",
+        CONF_PROMPT: "you are an assistant",
+        CONF_LLM_HASS_API: ["assist"],
+    }
async def test_create_conversation_agent_no_control(
hass: HomeAssistant,
mock_open_router_client: AsyncMock,
mock_openai_client: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating a conversation agent without control over the LLM API."""
mock_config_entry.add_to_hass(hass)
await setup_integration(hass, mock_config_entry)
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": SOURCE_USER},
)
assert result["type"] is FlowResultType.FORM
assert not result["errors"]
assert result["step_id"] == "user"
assert result["data_schema"].schema["model"].config["options"] == [
{"value": "gpt-3.5-turbo", "label": "GPT-3.5 Turbo"},
]
result = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{
CONF_MODEL: "gpt-3.5-turbo",
CONF_PROMPT: "you are an assistant",
CONF_LLM_HASS_API: [],
},
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["data"] == {
CONF_MODEL: "gpt-3.5-turbo",
CONF_PROMPT: "you are an assistant",
}

tests/components/open_router/test_conversation.py

@@ -3,16 +3,24 @@
from unittest.mock import AsyncMock
from freezegun import freeze_time
from openai.types import CompletionUsage
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessage,
ChatCompletionMessageToolCall,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message_tool_call import Function
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant
-from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
+from homeassistant.helpers import entity_registry as er, intent
from . import setup_integration
-from tests.common import MockConfigEntry
+from tests.common import MockConfigEntry, snapshot_platform
from tests.components.conversation import MockChatLog, mock_chat_log # noqa: F401
@@ -23,11 +31,23 @@ def freeze_the_time():
yield
@pytest.mark.parametrize("enable_assist", [True, False], ids=["assist", "no_assist"])
async def test_all_entities(
hass: HomeAssistant,
snapshot: SnapshotAssertion,
mock_openai_client: AsyncMock,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test all entities."""
await setup_integration(hass, mock_config_entry)
await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id)
async def test_default_prompt(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
-    area_registry: ar.AreaRegistry,
-    device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
mock_openai_client: AsyncMock,
mock_chat_log: MockChatLog, # noqa: F811
@@ -50,3 +70,95 @@ async def test_default_prompt(
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
"X-Title": "Home Assistant",
}
@pytest.mark.parametrize("enable_assist", [True])
async def test_function_call(
hass: HomeAssistant,
mock_chat_log: MockChatLog, # noqa: F811
mock_config_entry: MockConfigEntry,
snapshot: SnapshotAssertion,
mock_openai_client: AsyncMock,
) -> None:
"""Test function call from the assistant."""
await setup_integration(hass, mock_config_entry)
mock_chat_log.mock_tool_results(
{
"call_call_1": "value1",
"call_call_2": "value2",
}
)
async def completion_result(*args, messages, **kwargs):
for message in messages:
role = message["role"] if isinstance(message, dict) else message.role
if role == "tool":
return ChatCompletion(
id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content="I have successfully called the function",
role="assistant",
function_call=None,
tool_calls=None,
),
)
],
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
)
return ChatCompletion(
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
role="assistant",
function_call=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="call_call_1",
function=Function(
arguments='{"param1":"call1"}',
name="test_tool",
),
type="function",
)
],
),
)
],
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
)
mock_openai_client.chat.completions.create = completion_result
result = await conversation.async_converse(
hass,
"Please call the test function",
mock_chat_log.conversation_id,
Context(),
agent_id="conversation.gpt_3_5_turbo",
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
# Don't test the prompt, as it's not deterministic
assert mock_chat_log.content[1:] == snapshot