mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 07:37:34 +00:00
Add support for calling tools in Open Router (#148881)
This commit is contained in:
parent
073ea813f0
commit
50688bbd69
@ -16,8 +16,9 @@ from homeassistant.config_entries import (
|
||||
ConfigSubentryFlow,
|
||||
SubentryFlowResult,
|
||||
)
|
||||
from homeassistant.const import CONF_API_KEY, CONF_MODEL
|
||||
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.httpx_client import get_async_client
|
||||
from homeassistant.helpers.selector import (
|
||||
@ -25,9 +26,10 @@ from homeassistant.helpers.selector import (
|
||||
SelectSelector,
|
||||
SelectSelectorConfig,
|
||||
SelectSelectorMode,
|
||||
TemplateSelector,
|
||||
)
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import CONF_PROMPT, DOMAIN, RECOMMENDED_CONVERSATION_OPTIONS
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -90,6 +92,8 @@ class ConversationFlowHandler(ConfigSubentryFlow):
|
||||
) -> SubentryFlowResult:
|
||||
"""User flow to create a sensor subentry."""
|
||||
if user_input is not None:
|
||||
if not user_input.get(CONF_LLM_HASS_API):
|
||||
user_input.pop(CONF_LLM_HASS_API, None)
|
||||
return self.async_create_entry(
|
||||
title=self.options[user_input[CONF_MODEL]], data=user_input
|
||||
)
|
||||
@ -99,11 +103,17 @@ class ConversationFlowHandler(ConfigSubentryFlow):
|
||||
api_key=entry.data[CONF_API_KEY],
|
||||
http_client=get_async_client(self.hass),
|
||||
)
|
||||
hass_apis: list[SelectOptionDict] = [
|
||||
SelectOptionDict(
|
||||
label=api.name,
|
||||
value=api.id,
|
||||
)
|
||||
for api in llm.async_get_apis(self.hass)
|
||||
]
|
||||
options = []
|
||||
async for model in client.with_options(timeout=10.0).models.list():
|
||||
options.append(SelectOptionDict(value=model.id, label=model.name)) # type: ignore[attr-defined]
|
||||
self.options[model.id] = model.name # type: ignore[attr-defined]
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=vol.Schema(
|
||||
@ -113,6 +123,20 @@ class ConversationFlowHandler(ConfigSubentryFlow):
|
||||
options=options, mode=SelectSelectorMode.DROPDOWN, sort=True
|
||||
),
|
||||
),
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={
|
||||
"suggested_value": RECOMMENDED_CONVERSATION_OPTIONS[
|
||||
CONF_PROMPT
|
||||
]
|
||||
},
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
default=RECOMMENDED_CONVERSATION_OPTIONS[CONF_LLM_HASS_API],
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(options=hass_apis, multiple=True)
|
||||
),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
@ -2,5 +2,17 @@
|
||||
|
||||
import logging
|
||||
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
DOMAIN = "open_router"
|
||||
LOGGER = logging.getLogger(__package__)
|
||||
|
||||
CONF_PROMPT = "prompt"
|
||||
CONF_RECOMMENDED = "recommended"
|
||||
|
||||
RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
||||
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||
}
|
||||
|
@ -1,25 +1,39 @@
|
||||
"""Conversation support for OpenRouter."""
|
||||
|
||||
from typing import Literal
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
import json
|
||||
from typing import Any, Literal
|
||||
|
||||
import openai
|
||||
from openai import NOT_GIVEN
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
||||
from openai.types.shared_params import FunctionDefinition
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from . import OpenRouterConfigEntry
|
||||
from .const import DOMAIN, LOGGER
|
||||
from .const import CONF_PROMPT, DOMAIN, LOGGER
|
||||
|
||||
# Max number of back and forth with the LLM to generate a response
|
||||
MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
@ -35,13 +49,31 @@ async def async_setup_entry(
|
||||
)
|
||||
|
||||
|
||||
def _format_tool(
|
||||
tool: llm.Tool,
|
||||
custom_serializer: Callable[[Any], Any] | None,
|
||||
) -> ChatCompletionToolParam:
|
||||
"""Format tool specification."""
|
||||
tool_spec = FunctionDefinition(
|
||||
name=tool.name,
|
||||
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
|
||||
)
|
||||
if tool.description:
|
||||
tool_spec["description"] = tool.description
|
||||
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||
|
||||
|
||||
def _convert_content_to_chat_message(
|
||||
content: conversation.Content,
|
||||
) -> ChatCompletionMessageParam | None:
|
||||
"""Convert any native chat message for this agent to the native format."""
|
||||
LOGGER.debug("_convert_content_to_chat_message=%s", content)
|
||||
if isinstance(content, conversation.ToolResultContent):
|
||||
return None
|
||||
return ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
tool_call_id=content.tool_call_id,
|
||||
content=json.dumps(content.tool_result),
|
||||
)
|
||||
|
||||
role: Literal["user", "assistant", "system"] = content.role
|
||||
if role == "system" and content.content:
|
||||
@ -51,13 +83,55 @@ def _convert_content_to_chat_message(
|
||||
return ChatCompletionUserMessageParam(role="user", content=content.content)
|
||||
|
||||
if role == "assistant":
|
||||
return ChatCompletionAssistantMessageParam(
|
||||
role="assistant", content=content.content
|
||||
param = ChatCompletionAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=content.content,
|
||||
)
|
||||
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
|
||||
param["tool_calls"] = [
|
||||
ChatCompletionMessageToolCallParam(
|
||||
type="function",
|
||||
id=tool_call.id,
|
||||
function=Function(
|
||||
arguments=json.dumps(tool_call.tool_args),
|
||||
name=tool_call.tool_name,
|
||||
),
|
||||
)
|
||||
for tool_call in content.tool_calls
|
||||
]
|
||||
return param
|
||||
LOGGER.warning("Could not convert message to Completions API: %s", content)
|
||||
return None
|
||||
|
||||
|
||||
def _decode_tool_arguments(arguments: str) -> Any:
|
||||
"""Decode tool call arguments."""
|
||||
try:
|
||||
return json.loads(arguments)
|
||||
except json.JSONDecodeError as err:
|
||||
raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err
|
||||
|
||||
|
||||
async def _transform_response(
|
||||
message: ChatCompletionMessage,
|
||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||
"""Transform the OpenRouter message to a ChatLog format."""
|
||||
data: conversation.AssistantContentDeltaDict = {
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
}
|
||||
if message.tool_calls:
|
||||
data["tool_calls"] = [
|
||||
llm.ToolInput(
|
||||
id=tool_call.id,
|
||||
tool_name=tool_call.function.name,
|
||||
tool_args=_decode_tool_arguments(tool_call.function.arguments),
|
||||
)
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
yield data
|
||||
|
||||
|
||||
class OpenRouterConversationEntity(conversation.ConversationEntity):
|
||||
"""OpenRouter conversation agent."""
|
||||
|
||||
@ -75,6 +149,10 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
||||
name=subentry.title,
|
||||
entry_type=DeviceEntryType.SERVICE,
|
||||
)
|
||||
if self.subentry.data.get(CONF_LLM_HASS_API):
|
||||
self._attr_supported_features = (
|
||||
conversation.ConversationEntityFeature.CONTROL
|
||||
)
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str] | Literal["*"]:
|
||||
@ -93,12 +171,19 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
||||
await chat_log.async_provide_llm_data(
|
||||
user_input.as_llm_context(DOMAIN),
|
||||
options.get(CONF_LLM_HASS_API),
|
||||
None,
|
||||
options.get(CONF_PROMPT),
|
||||
user_input.extra_system_prompt,
|
||||
)
|
||||
except conversation.ConverseError as err:
|
||||
return err.as_conversation_result()
|
||||
|
||||
tools: list[ChatCompletionToolParam] | None = None
|
||||
if chat_log.llm_api:
|
||||
tools = [
|
||||
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
||||
for tool in chat_log.llm_api.tools
|
||||
]
|
||||
|
||||
messages = [
|
||||
m
|
||||
for content in chat_log.content
|
||||
@ -107,10 +192,12 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
||||
|
||||
client = self.entry.runtime_data
|
||||
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
try:
|
||||
result = await client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
tools=tools or NOT_GIVEN,
|
||||
user=chat_log.conversation_id,
|
||||
extra_headers={
|
||||
"X-Title": "Home Assistant",
|
||||
@ -123,11 +210,16 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
||||
|
||||
result_message = result.choices[0].message
|
||||
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
conversation.AssistantContent(
|
||||
agent_id=user_input.agent_id,
|
||||
content=result_message.content,
|
||||
messages.extend(
|
||||
[
|
||||
msg
|
||||
async for content in chat_log.async_add_delta_content_stream(
|
||||
user_input.agent_id, _transform_response(result_message)
|
||||
)
|
||||
if (msg := _convert_content_to_chat_message(content))
|
||||
]
|
||||
)
|
||||
if not chat_log.unresponded_tool_results:
|
||||
break
|
||||
|
||||
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||
|
@ -24,7 +24,13 @@
|
||||
"user": {
|
||||
"description": "Configure the new conversation agent",
|
||||
"data": {
|
||||
"model": "Model"
|
||||
"model": "Model",
|
||||
"prompt": "Instructions",
|
||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
|
||||
},
|
||||
"data_description": {
|
||||
"model": "The model to use for the conversation agent",
|
||||
"prompt": "Instruct how the LLM should respond. This can be a template."
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from openai.types import CompletionUsage
|
||||
@ -9,10 +10,11 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.open_router.const import DOMAIN
|
||||
from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
|
||||
from homeassistant.config_entries import ConfigSubentryData
|
||||
from homeassistant.const import CONF_API_KEY, CONF_MODEL
|
||||
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
@ -29,7 +31,27 @@ def mock_setup_entry() -> Generator[AsyncMock]:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
def enable_assist() -> bool:
|
||||
"""Mock conversation subentry data."""
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conversation_subentry_data(enable_assist: bool) -> dict[str, Any]:
|
||||
"""Mock conversation subentry data."""
|
||||
res: dict[str, Any] = {
|
||||
CONF_MODEL: "gpt-3.5-turbo",
|
||||
CONF_PROMPT: "You are a helpful assistant.",
|
||||
}
|
||||
if enable_assist:
|
||||
res[CONF_LLM_HASS_API] = [llm.LLM_API_ASSIST]
|
||||
return res
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry(
|
||||
hass: HomeAssistant, conversation_subentry_data: dict[str, Any]
|
||||
) -> MockConfigEntry:
|
||||
"""Mock a config entry."""
|
||||
return MockConfigEntry(
|
||||
title="OpenRouter",
|
||||
@ -39,7 +61,7 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
},
|
||||
subentries_data=[
|
||||
ConfigSubentryData(
|
||||
data={CONF_MODEL: "gpt-3.5-turbo"},
|
||||
data=conversation_subentry_data,
|
||||
subentry_id="ABCDEF",
|
||||
subentry_type="conversation",
|
||||
title="GPT-3.5 Turbo",
|
||||
|
@ -1,4 +1,108 @@
|
||||
# serializer version: 1
|
||||
# name: test_all_entities[assist][conversation.gpt_3_5_turbo-entry]
|
||||
EntityRegistryEntrySnapshot({
|
||||
'aliases': set({
|
||||
}),
|
||||
'area_id': None,
|
||||
'capabilities': None,
|
||||
'config_entry_id': <ANY>,
|
||||
'config_subentry_id': <ANY>,
|
||||
'device_class': None,
|
||||
'device_id': <ANY>,
|
||||
'disabled_by': None,
|
||||
'domain': 'conversation',
|
||||
'entity_category': None,
|
||||
'entity_id': 'conversation.gpt_3_5_turbo',
|
||||
'has_entity_name': True,
|
||||
'hidden_by': None,
|
||||
'icon': None,
|
||||
'id': <ANY>,
|
||||
'labels': set({
|
||||
}),
|
||||
'name': None,
|
||||
'options': dict({
|
||||
'conversation': dict({
|
||||
'should_expose': False,
|
||||
}),
|
||||
}),
|
||||
'original_device_class': None,
|
||||
'original_icon': None,
|
||||
'original_name': None,
|
||||
'platform': 'open_router',
|
||||
'previous_unique_id': None,
|
||||
'suggested_object_id': None,
|
||||
'supported_features': <ConversationEntityFeature: 1>,
|
||||
'translation_key': None,
|
||||
'unique_id': 'ABCDEF',
|
||||
'unit_of_measurement': None,
|
||||
})
|
||||
# ---
|
||||
# name: test_all_entities[assist][conversation.gpt_3_5_turbo-state]
|
||||
StateSnapshot({
|
||||
'attributes': ReadOnlyDict({
|
||||
'friendly_name': 'GPT-3.5 Turbo',
|
||||
'supported_features': <ConversationEntityFeature: 1>,
|
||||
}),
|
||||
'context': <ANY>,
|
||||
'entity_id': 'conversation.gpt_3_5_turbo',
|
||||
'last_changed': <ANY>,
|
||||
'last_reported': <ANY>,
|
||||
'last_updated': <ANY>,
|
||||
'state': 'unknown',
|
||||
})
|
||||
# ---
|
||||
# name: test_all_entities[no_assist][conversation.gpt_3_5_turbo-entry]
|
||||
EntityRegistryEntrySnapshot({
|
||||
'aliases': set({
|
||||
}),
|
||||
'area_id': None,
|
||||
'capabilities': None,
|
||||
'config_entry_id': <ANY>,
|
||||
'config_subentry_id': <ANY>,
|
||||
'device_class': None,
|
||||
'device_id': <ANY>,
|
||||
'disabled_by': None,
|
||||
'domain': 'conversation',
|
||||
'entity_category': None,
|
||||
'entity_id': 'conversation.gpt_3_5_turbo',
|
||||
'has_entity_name': True,
|
||||
'hidden_by': None,
|
||||
'icon': None,
|
||||
'id': <ANY>,
|
||||
'labels': set({
|
||||
}),
|
||||
'name': None,
|
||||
'options': dict({
|
||||
'conversation': dict({
|
||||
'should_expose': False,
|
||||
}),
|
||||
}),
|
||||
'original_device_class': None,
|
||||
'original_icon': None,
|
||||
'original_name': None,
|
||||
'platform': 'open_router',
|
||||
'previous_unique_id': None,
|
||||
'suggested_object_id': None,
|
||||
'supported_features': 0,
|
||||
'translation_key': None,
|
||||
'unique_id': 'ABCDEF',
|
||||
'unit_of_measurement': None,
|
||||
})
|
||||
# ---
|
||||
# name: test_all_entities[no_assist][conversation.gpt_3_5_turbo-state]
|
||||
StateSnapshot({
|
||||
'attributes': ReadOnlyDict({
|
||||
'friendly_name': 'GPT-3.5 Turbo',
|
||||
'supported_features': <ConversationEntityFeature: 0>,
|
||||
}),
|
||||
'context': <ANY>,
|
||||
'entity_id': 'conversation.gpt_3_5_turbo',
|
||||
'last_changed': <ANY>,
|
||||
'last_reported': <ANY>,
|
||||
'last_updated': <ANY>,
|
||||
'state': 'unknown',
|
||||
})
|
||||
# ---
|
||||
# name: test_default_prompt
|
||||
list([
|
||||
dict({
|
||||
@ -14,3 +118,39 @@
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_function_call[True]
|
||||
list([
|
||||
dict({
|
||||
'attachments': None,
|
||||
'content': 'Please call the test function',
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'conversation.gpt_3_5_turbo',
|
||||
'content': None,
|
||||
'role': 'assistant',
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'id': 'call_call_1',
|
||||
'tool_args': dict({
|
||||
'param1': 'call1',
|
||||
}),
|
||||
'tool_name': 'test_tool',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'conversation.gpt_3_5_turbo',
|
||||
'role': 'tool_result',
|
||||
'tool_call_id': 'call_call_1',
|
||||
'tool_name': 'test_tool',
|
||||
'tool_result': 'value1',
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'conversation.gpt_3_5_turbo',
|
||||
'content': 'I have successfully called the function',
|
||||
'role': 'assistant',
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
|
@ -5,9 +5,9 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
from python_open_router import OpenRouterError
|
||||
|
||||
from homeassistant.components.open_router.const import DOMAIN
|
||||
from homeassistant.config_entries import SOURCE_USER, ConfigSubentry
|
||||
from homeassistant.const import CONF_API_KEY, CONF_MODEL
|
||||
from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
|
||||
from homeassistant.config_entries import SOURCE_USER
|
||||
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
@ -129,18 +129,56 @@ async def test_create_conversation_agent(
|
||||
|
||||
result = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{CONF_MODEL: "gpt-3.5-turbo"},
|
||||
{
|
||||
CONF_MODEL: "gpt-3.5-turbo",
|
||||
CONF_PROMPT: "you are an assistant",
|
||||
CONF_LLM_HASS_API: ["assist"],
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
subentry_id = list(mock_config_entry.subentries)[0]
|
||||
assert (
|
||||
ConfigSubentry(
|
||||
data={CONF_MODEL: "gpt-3.5-turbo"},
|
||||
subentry_id=subentry_id,
|
||||
subentry_type="conversation",
|
||||
title="GPT-3.5 Turbo",
|
||||
unique_id=None,
|
||||
assert result["data"] == {
|
||||
CONF_MODEL: "gpt-3.5-turbo",
|
||||
CONF_PROMPT: "you are an assistant",
|
||||
CONF_LLM_HASS_API: ["assist"],
|
||||
}
|
||||
|
||||
|
||||
async def test_create_conversation_agent_no_control(
|
||||
hass: HomeAssistant,
|
||||
mock_open_router_client: AsyncMock,
|
||||
mock_openai_client: AsyncMock,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test creating a conversation agent without control over the LLM API."""
|
||||
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
context={"source": SOURCE_USER},
|
||||
)
|
||||
in mock_config_entry.subentries.values()
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert not result["errors"]
|
||||
assert result["step_id"] == "user"
|
||||
|
||||
assert result["data_schema"].schema["model"].config["options"] == [
|
||||
{"value": "gpt-3.5-turbo", "label": "GPT-3.5 Turbo"},
|
||||
]
|
||||
|
||||
result = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_MODEL: "gpt-3.5-turbo",
|
||||
CONF_PROMPT: "you are an assistant",
|
||||
CONF_LLM_HASS_API: [],
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["data"] == {
|
||||
CONF_MODEL: "gpt-3.5-turbo",
|
||||
CONF_PROMPT: "you are an assistant",
|
||||
}
|
||||
|
@ -3,16 +3,24 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from freezegun import freeze_time
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
|
||||
from homeassistant.helpers import entity_registry as er, intent
|
||||
|
||||
from . import setup_integration
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.common import MockConfigEntry, snapshot_platform
|
||||
from tests.components.conversation import MockChatLog, mock_chat_log # noqa: F401
|
||||
|
||||
|
||||
@ -23,11 +31,23 @@ def freeze_the_time():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_assist", [True, False], ids=["assist", "no_assist"])
|
||||
async def test_all_entities(
|
||||
hass: HomeAssistant,
|
||||
snapshot: SnapshotAssertion,
|
||||
mock_openai_client: AsyncMock,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test all entities."""
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id)
|
||||
|
||||
|
||||
async def test_default_prompt(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
area_registry: ar.AreaRegistry,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
snapshot: SnapshotAssertion,
|
||||
mock_openai_client: AsyncMock,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
@ -50,3 +70,95 @@ async def test_default_prompt(
|
||||
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
||||
"X-Title": "Home Assistant",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_assist", [True])
|
||||
async def test_function_call(
|
||||
hass: HomeAssistant,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
mock_config_entry: MockConfigEntry,
|
||||
snapshot: SnapshotAssertion,
|
||||
mock_openai_client: AsyncMock,
|
||||
) -> None:
|
||||
"""Test function call from the assistant."""
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
mock_chat_log.mock_tool_results(
|
||||
{
|
||||
"call_call_1": "value1",
|
||||
"call_call_2": "value2",
|
||||
}
|
||||
)
|
||||
|
||||
async def completion_result(*args, messages, **kwargs):
|
||||
for message in messages:
|
||||
role = message["role"] if isinstance(message, dict) else message.role
|
||||
if role == "tool":
|
||||
return ChatCompletion(
|
||||
id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content="I have successfully called the function",
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
tool_calls=None,
|
||||
),
|
||||
)
|
||||
],
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion",
|
||||
system_fingerprint=None,
|
||||
usage=CompletionUsage(
|
||||
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
||||
),
|
||||
)
|
||||
|
||||
return ChatCompletion(
|
||||
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content=None,
|
||||
role="assistant",
|
||||
function_call=None,
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id="call_call_1",
|
||||
function=Function(
|
||||
arguments='{"param1":"call1"}',
|
||||
name="test_tool",
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
],
|
||||
created=1700000000,
|
||||
model="gpt-4-1106-preview",
|
||||
object="chat.completion",
|
||||
system_fingerprint=None,
|
||||
usage=CompletionUsage(
|
||||
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
||||
),
|
||||
)
|
||||
|
||||
mock_openai_client.chat.completions.create = completion_result
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
mock_chat_log.conversation_id,
|
||||
Context(),
|
||||
agent_id="conversation.gpt_3_5_turbo",
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
# Don't test the prompt, as it's not deterministic
|
||||
assert mock_chat_log.content[1:] == snapshot
|
||||
|
Loading…
x
Reference in New Issue
Block a user