Add support for calling tools in Open Router (#148881)

This commit is contained in:
Joost Lekkerkerker 2025-07-18 05:49:27 +02:00 committed by GitHub
parent 073ea813f0
commit 50688bbd69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 497 additions and 51 deletions

View File

@ -16,8 +16,9 @@ from homeassistant.config_entries import (
ConfigSubentryFlow, ConfigSubentryFlow,
SubentryFlowResult, SubentryFlowResult,
) )
from homeassistant.const import CONF_API_KEY, CONF_MODEL from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers import llm
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.httpx_client import get_async_client from homeassistant.helpers.httpx_client import get_async_client
from homeassistant.helpers.selector import ( from homeassistant.helpers.selector import (
@ -25,9 +26,10 @@ from homeassistant.helpers.selector import (
SelectSelector, SelectSelector,
SelectSelectorConfig, SelectSelectorConfig,
SelectSelectorMode, SelectSelectorMode,
TemplateSelector,
) )
from .const import DOMAIN from .const import CONF_PROMPT, DOMAIN, RECOMMENDED_CONVERSATION_OPTIONS
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -90,6 +92,8 @@ class ConversationFlowHandler(ConfigSubentryFlow):
) -> SubentryFlowResult: ) -> SubentryFlowResult:
"""User flow to create a sensor subentry.""" """User flow to create a sensor subentry."""
if user_input is not None: if user_input is not None:
if not user_input.get(CONF_LLM_HASS_API):
user_input.pop(CONF_LLM_HASS_API, None)
return self.async_create_entry( return self.async_create_entry(
title=self.options[user_input[CONF_MODEL]], data=user_input title=self.options[user_input[CONF_MODEL]], data=user_input
) )
@ -99,11 +103,17 @@ class ConversationFlowHandler(ConfigSubentryFlow):
api_key=entry.data[CONF_API_KEY], api_key=entry.data[CONF_API_KEY],
http_client=get_async_client(self.hass), http_client=get_async_client(self.hass),
) )
hass_apis: list[SelectOptionDict] = [
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(self.hass)
]
options = [] options = []
async for model in client.with_options(timeout=10.0).models.list(): async for model in client.with_options(timeout=10.0).models.list():
options.append(SelectOptionDict(value=model.id, label=model.name)) # type: ignore[attr-defined] options.append(SelectOptionDict(value=model.id, label=model.name)) # type: ignore[attr-defined]
self.options[model.id] = model.name # type: ignore[attr-defined] self.options[model.id] = model.name # type: ignore[attr-defined]
return self.async_show_form( return self.async_show_form(
step_id="user", step_id="user",
data_schema=vol.Schema( data_schema=vol.Schema(
@ -113,6 +123,20 @@ class ConversationFlowHandler(ConfigSubentryFlow):
options=options, mode=SelectSelectorMode.DROPDOWN, sort=True options=options, mode=SelectSelectorMode.DROPDOWN, sort=True
), ),
), ),
vol.Optional(
CONF_PROMPT,
description={
"suggested_value": RECOMMENDED_CONVERSATION_OPTIONS[
CONF_PROMPT
]
},
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
default=RECOMMENDED_CONVERSATION_OPTIONS[CONF_LLM_HASS_API],
): SelectSelector(
SelectSelectorConfig(options=hass_apis, multiple=True)
),
} }
), ),
) )

View File

@ -2,5 +2,17 @@
import logging import logging
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.helpers import llm
DOMAIN = "open_router" DOMAIN = "open_router"
LOGGER = logging.getLogger(__package__) LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"
CONF_RECOMMENDED = "recommended"
RECOMMENDED_CONVERSATION_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
}

View File

@ -1,25 +1,39 @@
"""Conversation support for OpenRouter.""" """Conversation support for OpenRouter."""
from typing import Literal from collections.abc import AsyncGenerator, Callable
import json
from typing import Any, Literal
import openai import openai
from openai import NOT_GIVEN
from openai.types.chat import ( from openai.types.chat import (
ChatCompletionAssistantMessageParam, ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
ChatCompletionSystemMessageParam, ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionToolParam,
ChatCompletionUserMessageParam, ChatCompletionUserMessageParam,
) )
from openai.types.chat.chat_completion_message_tool_call_param import Function
from openai.types.shared_params import FunctionDefinition
from voluptuous_openapi import convert
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.config_entries import ConfigSubentry from homeassistant.config_entries import ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import OpenRouterConfigEntry from . import OpenRouterConfigEntry
from .const import DOMAIN, LOGGER from .const import CONF_PROMPT, DOMAIN, LOGGER
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
async def async_setup_entry( async def async_setup_entry(
@ -35,13 +49,31 @@ async def async_setup_entry(
) )
def _format_tool(
tool: llm.Tool,
custom_serializer: Callable[[Any], Any] | None,
) -> ChatCompletionToolParam:
"""Format tool specification."""
tool_spec = FunctionDefinition(
name=tool.name,
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
)
if tool.description:
tool_spec["description"] = tool.description
return ChatCompletionToolParam(type="function", function=tool_spec)
def _convert_content_to_chat_message( def _convert_content_to_chat_message(
content: conversation.Content, content: conversation.Content,
) -> ChatCompletionMessageParam | None: ) -> ChatCompletionMessageParam | None:
"""Convert any native chat message for this agent to the native format.""" """Convert any native chat message for this agent to the native format."""
LOGGER.debug("_convert_content_to_chat_message=%s", content) LOGGER.debug("_convert_content_to_chat_message=%s", content)
if isinstance(content, conversation.ToolResultContent): if isinstance(content, conversation.ToolResultContent):
return None return ChatCompletionToolMessageParam(
role="tool",
tool_call_id=content.tool_call_id,
content=json.dumps(content.tool_result),
)
role: Literal["user", "assistant", "system"] = content.role role: Literal["user", "assistant", "system"] = content.role
if role == "system" and content.content: if role == "system" and content.content:
@ -51,13 +83,55 @@ def _convert_content_to_chat_message(
return ChatCompletionUserMessageParam(role="user", content=content.content) return ChatCompletionUserMessageParam(role="user", content=content.content)
if role == "assistant": if role == "assistant":
return ChatCompletionAssistantMessageParam( param = ChatCompletionAssistantMessageParam(
role="assistant", content=content.content role="assistant",
content=content.content,
) )
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
param["tool_calls"] = [
ChatCompletionMessageToolCallParam(
type="function",
id=tool_call.id,
function=Function(
arguments=json.dumps(tool_call.tool_args),
name=tool_call.tool_name,
),
)
for tool_call in content.tool_calls
]
return param
LOGGER.warning("Could not convert message to Completions API: %s", content) LOGGER.warning("Could not convert message to Completions API: %s", content)
return None return None
def _decode_tool_arguments(arguments: str) -> Any:
"""Decode tool call arguments."""
try:
return json.loads(arguments)
except json.JSONDecodeError as err:
raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err
async def _transform_response(
message: ChatCompletionMessage,
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
"""Transform the OpenRouter message to a ChatLog format."""
data: conversation.AssistantContentDeltaDict = {
"role": message.role,
"content": message.content,
}
if message.tool_calls:
data["tool_calls"] = [
llm.ToolInput(
id=tool_call.id,
tool_name=tool_call.function.name,
tool_args=_decode_tool_arguments(tool_call.function.arguments),
)
for tool_call in message.tool_calls
]
yield data
class OpenRouterConversationEntity(conversation.ConversationEntity): class OpenRouterConversationEntity(conversation.ConversationEntity):
"""OpenRouter conversation agent.""" """OpenRouter conversation agent."""
@ -75,6 +149,10 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
name=subentry.title, name=subentry.title,
entry_type=DeviceEntryType.SERVICE, entry_type=DeviceEntryType.SERVICE,
) )
if self.subentry.data.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
)
@property @property
def supported_languages(self) -> list[str] | Literal["*"]: def supported_languages(self) -> list[str] | Literal["*"]:
@ -93,12 +171,19 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
await chat_log.async_provide_llm_data( await chat_log.async_provide_llm_data(
user_input.as_llm_context(DOMAIN), user_input.as_llm_context(DOMAIN),
options.get(CONF_LLM_HASS_API), options.get(CONF_LLM_HASS_API),
None, options.get(CONF_PROMPT),
user_input.extra_system_prompt, user_input.extra_system_prompt,
) )
except conversation.ConverseError as err: except conversation.ConverseError as err:
return err.as_conversation_result() return err.as_conversation_result()
tools: list[ChatCompletionToolParam] | None = None
if chat_log.llm_api:
tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
messages = [ messages = [
m m
for content in chat_log.content for content in chat_log.content
@ -107,27 +192,34 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
client = self.entry.runtime_data client = self.entry.runtime_data
try: for _iteration in range(MAX_TOOL_ITERATIONS):
result = await client.chat.completions.create( try:
model=self.model, result = await client.chat.completions.create(
messages=messages, model=self.model,
user=chat_log.conversation_id, messages=messages,
extra_headers={ tools=tools or NOT_GIVEN,
"X-Title": "Home Assistant", user=chat_log.conversation_id,
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router", extra_headers={
}, "X-Title": "Home Assistant",
) "HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
except openai.OpenAIError as err: },
LOGGER.error("Error talking to API: %s", err) )
raise HomeAssistantError("Error talking to API") from err except openai.OpenAIError as err:
LOGGER.error("Error talking to API: %s", err)
raise HomeAssistantError("Error talking to API") from err
result_message = result.choices[0].message result_message = result.choices[0].message
chat_log.async_add_assistant_content_without_tools( messages.extend(
conversation.AssistantContent( [
agent_id=user_input.agent_id, msg
content=result_message.content, async for content in chat_log.async_add_delta_content_stream(
user_input.agent_id, _transform_response(result_message)
)
if (msg := _convert_content_to_chat_message(content))
]
) )
) if not chat_log.unresponded_tool_results:
break
return conversation.async_get_result_from_chat_log(user_input, chat_log) return conversation.async_get_result_from_chat_log(user_input, chat_log)

View File

@ -24,7 +24,13 @@
"user": { "user": {
"description": "Configure the new conversation agent", "description": "Configure the new conversation agent",
"data": { "data": {
"model": "Model" "model": "Model",
"prompt": "Instructions",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
},
"data_description": {
"model": "The model to use for the conversation agent",
"prompt": "Instruct how the LLM should respond. This can be a template."
} }
} }
}, },

View File

@ -2,6 +2,7 @@
from collections.abc import AsyncGenerator, Generator from collections.abc import AsyncGenerator, Generator
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from openai.types import CompletionUsage from openai.types import CompletionUsage
@ -9,10 +10,11 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion import Choice
import pytest import pytest
from homeassistant.components.open_router.const import DOMAIN from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
from homeassistant.config_entries import ConfigSubentryData from homeassistant.config_entries import ConfigSubentryData
from homeassistant.const import CONF_API_KEY, CONF_MODEL from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -29,7 +31,27 @@ def mock_setup_entry() -> Generator[AsyncMock]:
@pytest.fixture @pytest.fixture
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: def enable_assist() -> bool:
"""Mock conversation subentry data."""
return False
@pytest.fixture
def conversation_subentry_data(enable_assist: bool) -> dict[str, Any]:
"""Mock conversation subentry data."""
res: dict[str, Any] = {
CONF_MODEL: "gpt-3.5-turbo",
CONF_PROMPT: "You are a helpful assistant.",
}
if enable_assist:
res[CONF_LLM_HASS_API] = [llm.LLM_API_ASSIST]
return res
@pytest.fixture
def mock_config_entry(
hass: HomeAssistant, conversation_subentry_data: dict[str, Any]
) -> MockConfigEntry:
"""Mock a config entry.""" """Mock a config entry."""
return MockConfigEntry( return MockConfigEntry(
title="OpenRouter", title="OpenRouter",
@ -39,7 +61,7 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
}, },
subentries_data=[ subentries_data=[
ConfigSubentryData( ConfigSubentryData(
data={CONF_MODEL: "gpt-3.5-turbo"}, data=conversation_subentry_data,
subentry_id="ABCDEF", subentry_id="ABCDEF",
subentry_type="conversation", subentry_type="conversation",
title="GPT-3.5 Turbo", title="GPT-3.5 Turbo",

View File

@ -1,4 +1,108 @@
# serializer version: 1 # serializer version: 1
# name: test_all_entities[assist][conversation.gpt_3_5_turbo-entry]
EntityRegistryEntrySnapshot({
'aliases': set({
}),
'area_id': None,
'capabilities': None,
'config_entry_id': <ANY>,
'config_subentry_id': <ANY>,
'device_class': None,
'device_id': <ANY>,
'disabled_by': None,
'domain': 'conversation',
'entity_category': None,
'entity_id': 'conversation.gpt_3_5_turbo',
'has_entity_name': True,
'hidden_by': None,
'icon': None,
'id': <ANY>,
'labels': set({
}),
'name': None,
'options': dict({
'conversation': dict({
'should_expose': False,
}),
}),
'original_device_class': None,
'original_icon': None,
'original_name': None,
'platform': 'open_router',
'previous_unique_id': None,
'suggested_object_id': None,
'supported_features': <ConversationEntityFeature: 1>,
'translation_key': None,
'unique_id': 'ABCDEF',
'unit_of_measurement': None,
})
# ---
# name: test_all_entities[assist][conversation.gpt_3_5_turbo-state]
StateSnapshot({
'attributes': ReadOnlyDict({
'friendly_name': 'GPT-3.5 Turbo',
'supported_features': <ConversationEntityFeature: 1>,
}),
'context': <ANY>,
'entity_id': 'conversation.gpt_3_5_turbo',
'last_changed': <ANY>,
'last_reported': <ANY>,
'last_updated': <ANY>,
'state': 'unknown',
})
# ---
# name: test_all_entities[no_assist][conversation.gpt_3_5_turbo-entry]
EntityRegistryEntrySnapshot({
'aliases': set({
}),
'area_id': None,
'capabilities': None,
'config_entry_id': <ANY>,
'config_subentry_id': <ANY>,
'device_class': None,
'device_id': <ANY>,
'disabled_by': None,
'domain': 'conversation',
'entity_category': None,
'entity_id': 'conversation.gpt_3_5_turbo',
'has_entity_name': True,
'hidden_by': None,
'icon': None,
'id': <ANY>,
'labels': set({
}),
'name': None,
'options': dict({
'conversation': dict({
'should_expose': False,
}),
}),
'original_device_class': None,
'original_icon': None,
'original_name': None,
'platform': 'open_router',
'previous_unique_id': None,
'suggested_object_id': None,
'supported_features': 0,
'translation_key': None,
'unique_id': 'ABCDEF',
'unit_of_measurement': None,
})
# ---
# name: test_all_entities[no_assist][conversation.gpt_3_5_turbo-state]
StateSnapshot({
'attributes': ReadOnlyDict({
'friendly_name': 'GPT-3.5 Turbo',
'supported_features': <ConversationEntityFeature: 0>,
}),
'context': <ANY>,
'entity_id': 'conversation.gpt_3_5_turbo',
'last_changed': <ANY>,
'last_reported': <ANY>,
'last_updated': <ANY>,
'state': 'unknown',
})
# ---
# name: test_default_prompt # name: test_default_prompt
list([ list([
dict({ dict({
@ -14,3 +118,39 @@
}), }),
]) ])
# --- # ---
# name: test_function_call[True]
list([
dict({
'attachments': None,
'content': 'Please call the test function',
'role': 'user',
}),
dict({
'agent_id': 'conversation.gpt_3_5_turbo',
'content': None,
'role': 'assistant',
'tool_calls': list([
dict({
'id': 'call_call_1',
'tool_args': dict({
'param1': 'call1',
}),
'tool_name': 'test_tool',
}),
]),
}),
dict({
'agent_id': 'conversation.gpt_3_5_turbo',
'role': 'tool_result',
'tool_call_id': 'call_call_1',
'tool_name': 'test_tool',
'tool_result': 'value1',
}),
dict({
'agent_id': 'conversation.gpt_3_5_turbo',
'content': 'I have successfully called the function',
'role': 'assistant',
'tool_calls': None,
}),
])
# ---

View File

@ -5,9 +5,9 @@ from unittest.mock import AsyncMock
import pytest import pytest
from python_open_router import OpenRouterError from python_open_router import OpenRouterError
from homeassistant.components.open_router.const import DOMAIN from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
from homeassistant.config_entries import SOURCE_USER, ConfigSubentry from homeassistant.config_entries import SOURCE_USER
from homeassistant.const import CONF_API_KEY, CONF_MODEL from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
@ -129,18 +129,56 @@ async def test_create_conversation_agent(
result = await hass.config_entries.subentries.async_configure( result = await hass.config_entries.subentries.async_configure(
result["flow_id"], result["flow_id"],
{CONF_MODEL: "gpt-3.5-turbo"}, {
CONF_MODEL: "gpt-3.5-turbo",
CONF_PROMPT: "you are an assistant",
CONF_LLM_HASS_API: ["assist"],
},
) )
assert result["type"] is FlowResultType.CREATE_ENTRY assert result["type"] is FlowResultType.CREATE_ENTRY
subentry_id = list(mock_config_entry.subentries)[0] assert result["data"] == {
assert ( CONF_MODEL: "gpt-3.5-turbo",
ConfigSubentry( CONF_PROMPT: "you are an assistant",
data={CONF_MODEL: "gpt-3.5-turbo"}, CONF_LLM_HASS_API: ["assist"],
subentry_id=subentry_id, }
subentry_type="conversation",
title="GPT-3.5 Turbo",
unique_id=None, async def test_create_conversation_agent_no_control(
) hass: HomeAssistant,
in mock_config_entry.subentries.values() mock_open_router_client: AsyncMock,
mock_openai_client: AsyncMock,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating a conversation agent without control over the LLM API."""
mock_config_entry.add_to_hass(hass)
await setup_integration(hass, mock_config_entry)
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": SOURCE_USER},
) )
assert result["type"] is FlowResultType.FORM
assert not result["errors"]
assert result["step_id"] == "user"
assert result["data_schema"].schema["model"].config["options"] == [
{"value": "gpt-3.5-turbo", "label": "GPT-3.5 Turbo"},
]
result = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{
CONF_MODEL: "gpt-3.5-turbo",
CONF_PROMPT: "you are an assistant",
CONF_LLM_HASS_API: [],
},
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["data"] == {
CONF_MODEL: "gpt-3.5-turbo",
CONF_PROMPT: "you are an assistant",
}

View File

@ -3,16 +3,24 @@
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
from freezegun import freeze_time from freezegun import freeze_time
from openai.types import CompletionUsage
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessage,
ChatCompletionMessageToolCall,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message_tool_call import Function
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent from homeassistant.helpers import entity_registry as er, intent
from . import setup_integration from . import setup_integration
from tests.common import MockConfigEntry from tests.common import MockConfigEntry, snapshot_platform
from tests.components.conversation import MockChatLog, mock_chat_log # noqa: F401 from tests.components.conversation import MockChatLog, mock_chat_log # noqa: F401
@ -23,11 +31,23 @@ def freeze_the_time():
yield yield
@pytest.mark.parametrize("enable_assist", [True, False], ids=["assist", "no_assist"])
async def test_all_entities(
hass: HomeAssistant,
snapshot: SnapshotAssertion,
mock_openai_client: AsyncMock,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test all entities."""
await setup_integration(hass, mock_config_entry)
await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id)
async def test_default_prompt( async def test_default_prompt(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
mock_openai_client: AsyncMock, mock_openai_client: AsyncMock,
mock_chat_log: MockChatLog, # noqa: F811 mock_chat_log: MockChatLog, # noqa: F811
@ -50,3 +70,95 @@ async def test_default_prompt(
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router", "HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
"X-Title": "Home Assistant", "X-Title": "Home Assistant",
} }
@pytest.mark.parametrize("enable_assist", [True])
async def test_function_call(
hass: HomeAssistant,
mock_chat_log: MockChatLog, # noqa: F811
mock_config_entry: MockConfigEntry,
snapshot: SnapshotAssertion,
mock_openai_client: AsyncMock,
) -> None:
"""Test function call from the assistant."""
await setup_integration(hass, mock_config_entry)
mock_chat_log.mock_tool_results(
{
"call_call_1": "value1",
"call_call_2": "value2",
}
)
async def completion_result(*args, messages, **kwargs):
for message in messages:
role = message["role"] if isinstance(message, dict) else message.role
if role == "tool":
return ChatCompletion(
id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content="I have successfully called the function",
role="assistant",
function_call=None,
tool_calls=None,
),
)
],
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
)
return ChatCompletion(
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
role="assistant",
function_call=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="call_call_1",
function=Function(
arguments='{"param1":"call1"}',
name="test_tool",
),
type="function",
)
],
),
)
],
created=1700000000,
model="gpt-4-1106-preview",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
)
mock_openai_client.chat.completions.create = completion_result
result = await conversation.async_converse(
hass,
"Please call the test function",
mock_chat_log.conversation_id,
Context(),
agent_id="conversation.gpt_3_5_turbo",
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
# Don't test the prompt, as it's not deterministic
assert mock_chat_log.content[1:] == snapshot