mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 15:47:12 +00:00
Add support for calling tools in Open Router (#148881)
This commit is contained in:
parent
073ea813f0
commit
50688bbd69
@ -16,8 +16,9 @@ from homeassistant.config_entries import (
|
|||||||
ConfigSubentryFlow,
|
ConfigSubentryFlow,
|
||||||
SubentryFlowResult,
|
SubentryFlowResult,
|
||||||
)
|
)
|
||||||
from homeassistant.const import CONF_API_KEY, CONF_MODEL
|
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
from homeassistant.helpers import llm
|
||||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
from homeassistant.helpers.httpx_client import get_async_client
|
from homeassistant.helpers.httpx_client import get_async_client
|
||||||
from homeassistant.helpers.selector import (
|
from homeassistant.helpers.selector import (
|
||||||
@ -25,9 +26,10 @@ from homeassistant.helpers.selector import (
|
|||||||
SelectSelector,
|
SelectSelector,
|
||||||
SelectSelectorConfig,
|
SelectSelectorConfig,
|
||||||
SelectSelectorMode,
|
SelectSelectorMode,
|
||||||
|
TemplateSelector,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import CONF_PROMPT, DOMAIN, RECOMMENDED_CONVERSATION_OPTIONS
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -90,6 +92,8 @@ class ConversationFlowHandler(ConfigSubentryFlow):
|
|||||||
) -> SubentryFlowResult:
|
) -> SubentryFlowResult:
|
||||||
"""User flow to create a sensor subentry."""
|
"""User flow to create a sensor subentry."""
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
|
if not user_input.get(CONF_LLM_HASS_API):
|
||||||
|
user_input.pop(CONF_LLM_HASS_API, None)
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=self.options[user_input[CONF_MODEL]], data=user_input
|
title=self.options[user_input[CONF_MODEL]], data=user_input
|
||||||
)
|
)
|
||||||
@ -99,11 +103,17 @@ class ConversationFlowHandler(ConfigSubentryFlow):
|
|||||||
api_key=entry.data[CONF_API_KEY],
|
api_key=entry.data[CONF_API_KEY],
|
||||||
http_client=get_async_client(self.hass),
|
http_client=get_async_client(self.hass),
|
||||||
)
|
)
|
||||||
|
hass_apis: list[SelectOptionDict] = [
|
||||||
|
SelectOptionDict(
|
||||||
|
label=api.name,
|
||||||
|
value=api.id,
|
||||||
|
)
|
||||||
|
for api in llm.async_get_apis(self.hass)
|
||||||
|
]
|
||||||
options = []
|
options = []
|
||||||
async for model in client.with_options(timeout=10.0).models.list():
|
async for model in client.with_options(timeout=10.0).models.list():
|
||||||
options.append(SelectOptionDict(value=model.id, label=model.name)) # type: ignore[attr-defined]
|
options.append(SelectOptionDict(value=model.id, label=model.name)) # type: ignore[attr-defined]
|
||||||
self.options[model.id] = model.name # type: ignore[attr-defined]
|
self.options[model.id] = model.name # type: ignore[attr-defined]
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="user",
|
step_id="user",
|
||||||
data_schema=vol.Schema(
|
data_schema=vol.Schema(
|
||||||
@ -113,6 +123,20 @@ class ConversationFlowHandler(ConfigSubentryFlow):
|
|||||||
options=options, mode=SelectSelectorMode.DROPDOWN, sort=True
|
options=options, mode=SelectSelectorMode.DROPDOWN, sort=True
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
vol.Optional(
|
||||||
|
CONF_PROMPT,
|
||||||
|
description={
|
||||||
|
"suggested_value": RECOMMENDED_CONVERSATION_OPTIONS[
|
||||||
|
CONF_PROMPT
|
||||||
|
]
|
||||||
|
},
|
||||||
|
): TemplateSelector(),
|
||||||
|
vol.Optional(
|
||||||
|
CONF_LLM_HASS_API,
|
||||||
|
default=RECOMMENDED_CONVERSATION_OPTIONS[CONF_LLM_HASS_API],
|
||||||
|
): SelectSelector(
|
||||||
|
SelectSelectorConfig(options=hass_apis, multiple=True)
|
||||||
|
),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -2,5 +2,17 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
|
from homeassistant.helpers import llm
|
||||||
|
|
||||||
DOMAIN = "open_router"
|
DOMAIN = "open_router"
|
||||||
LOGGER = logging.getLogger(__package__)
|
LOGGER = logging.getLogger(__package__)
|
||||||
|
|
||||||
|
CONF_PROMPT = "prompt"
|
||||||
|
CONF_RECOMMENDED = "recommended"
|
||||||
|
|
||||||
|
RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
||||||
|
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||||
|
}
|
||||||
|
@ -1,25 +1,39 @@
|
|||||||
"""Conversation support for OpenRouter."""
|
"""Conversation support for OpenRouter."""
|
||||||
|
|
||||||
from typing import Literal
|
from collections.abc import AsyncGenerator, Callable
|
||||||
|
import json
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
from openai import NOT_GIVEN
|
||||||
from openai.types.chat import (
|
from openai.types.chat import (
|
||||||
ChatCompletionAssistantMessageParam,
|
ChatCompletionAssistantMessageParam,
|
||||||
|
ChatCompletionMessage,
|
||||||
ChatCompletionMessageParam,
|
ChatCompletionMessageParam,
|
||||||
|
ChatCompletionMessageToolCallParam,
|
||||||
ChatCompletionSystemMessageParam,
|
ChatCompletionSystemMessageParam,
|
||||||
|
ChatCompletionToolMessageParam,
|
||||||
|
ChatCompletionToolParam,
|
||||||
ChatCompletionUserMessageParam,
|
ChatCompletionUserMessageParam,
|
||||||
)
|
)
|
||||||
|
from openai.types.chat.chat_completion_message_tool_call_param import Function
|
||||||
|
from openai.types.shared_params import FunctionDefinition
|
||||||
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.config_entries import ConfigSubentry
|
from homeassistant.config_entries import ConfigSubentry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, CONF_MODEL, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers import llm
|
||||||
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
|
||||||
from . import OpenRouterConfigEntry
|
from . import OpenRouterConfigEntry
|
||||||
from .const import DOMAIN, LOGGER
|
from .const import CONF_PROMPT, DOMAIN, LOGGER
|
||||||
|
|
||||||
|
# Max number of back and forth with the LLM to generate a response
|
||||||
|
MAX_TOOL_ITERATIONS = 10
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(
|
async def async_setup_entry(
|
||||||
@ -35,13 +49,31 @@ async def async_setup_entry(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_tool(
|
||||||
|
tool: llm.Tool,
|
||||||
|
custom_serializer: Callable[[Any], Any] | None,
|
||||||
|
) -> ChatCompletionToolParam:
|
||||||
|
"""Format tool specification."""
|
||||||
|
tool_spec = FunctionDefinition(
|
||||||
|
name=tool.name,
|
||||||
|
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
|
||||||
|
)
|
||||||
|
if tool.description:
|
||||||
|
tool_spec["description"] = tool.description
|
||||||
|
return ChatCompletionToolParam(type="function", function=tool_spec)
|
||||||
|
|
||||||
|
|
||||||
def _convert_content_to_chat_message(
|
def _convert_content_to_chat_message(
|
||||||
content: conversation.Content,
|
content: conversation.Content,
|
||||||
) -> ChatCompletionMessageParam | None:
|
) -> ChatCompletionMessageParam | None:
|
||||||
"""Convert any native chat message for this agent to the native format."""
|
"""Convert any native chat message for this agent to the native format."""
|
||||||
LOGGER.debug("_convert_content_to_chat_message=%s", content)
|
LOGGER.debug("_convert_content_to_chat_message=%s", content)
|
||||||
if isinstance(content, conversation.ToolResultContent):
|
if isinstance(content, conversation.ToolResultContent):
|
||||||
return None
|
return ChatCompletionToolMessageParam(
|
||||||
|
role="tool",
|
||||||
|
tool_call_id=content.tool_call_id,
|
||||||
|
content=json.dumps(content.tool_result),
|
||||||
|
)
|
||||||
|
|
||||||
role: Literal["user", "assistant", "system"] = content.role
|
role: Literal["user", "assistant", "system"] = content.role
|
||||||
if role == "system" and content.content:
|
if role == "system" and content.content:
|
||||||
@ -51,13 +83,55 @@ def _convert_content_to_chat_message(
|
|||||||
return ChatCompletionUserMessageParam(role="user", content=content.content)
|
return ChatCompletionUserMessageParam(role="user", content=content.content)
|
||||||
|
|
||||||
if role == "assistant":
|
if role == "assistant":
|
||||||
return ChatCompletionAssistantMessageParam(
|
param = ChatCompletionAssistantMessageParam(
|
||||||
role="assistant", content=content.content
|
role="assistant",
|
||||||
|
content=content.content,
|
||||||
)
|
)
|
||||||
|
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
|
||||||
|
param["tool_calls"] = [
|
||||||
|
ChatCompletionMessageToolCallParam(
|
||||||
|
type="function",
|
||||||
|
id=tool_call.id,
|
||||||
|
function=Function(
|
||||||
|
arguments=json.dumps(tool_call.tool_args),
|
||||||
|
name=tool_call.tool_name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for tool_call in content.tool_calls
|
||||||
|
]
|
||||||
|
return param
|
||||||
LOGGER.warning("Could not convert message to Completions API: %s", content)
|
LOGGER.warning("Could not convert message to Completions API: %s", content)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_tool_arguments(arguments: str) -> Any:
|
||||||
|
"""Decode tool call arguments."""
|
||||||
|
try:
|
||||||
|
return json.loads(arguments)
|
||||||
|
except json.JSONDecodeError as err:
|
||||||
|
raise HomeAssistantError(f"Unexpected tool argument response: {err}") from err
|
||||||
|
|
||||||
|
|
||||||
|
async def _transform_response(
|
||||||
|
message: ChatCompletionMessage,
|
||||||
|
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||||
|
"""Transform the OpenRouter message to a ChatLog format."""
|
||||||
|
data: conversation.AssistantContentDeltaDict = {
|
||||||
|
"role": message.role,
|
||||||
|
"content": message.content,
|
||||||
|
}
|
||||||
|
if message.tool_calls:
|
||||||
|
data["tool_calls"] = [
|
||||||
|
llm.ToolInput(
|
||||||
|
id=tool_call.id,
|
||||||
|
tool_name=tool_call.function.name,
|
||||||
|
tool_args=_decode_tool_arguments(tool_call.function.arguments),
|
||||||
|
)
|
||||||
|
for tool_call in message.tool_calls
|
||||||
|
]
|
||||||
|
yield data
|
||||||
|
|
||||||
|
|
||||||
class OpenRouterConversationEntity(conversation.ConversationEntity):
|
class OpenRouterConversationEntity(conversation.ConversationEntity):
|
||||||
"""OpenRouter conversation agent."""
|
"""OpenRouter conversation agent."""
|
||||||
|
|
||||||
@ -75,6 +149,10 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
|||||||
name=subentry.title,
|
name=subentry.title,
|
||||||
entry_type=DeviceEntryType.SERVICE,
|
entry_type=DeviceEntryType.SERVICE,
|
||||||
)
|
)
|
||||||
|
if self.subentry.data.get(CONF_LLM_HASS_API):
|
||||||
|
self._attr_supported_features = (
|
||||||
|
conversation.ConversationEntityFeature.CONTROL
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str] | Literal["*"]:
|
def supported_languages(self) -> list[str] | Literal["*"]:
|
||||||
@ -93,12 +171,19 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
|||||||
await chat_log.async_provide_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
user_input.as_llm_context(DOMAIN),
|
user_input.as_llm_context(DOMAIN),
|
||||||
options.get(CONF_LLM_HASS_API),
|
options.get(CONF_LLM_HASS_API),
|
||||||
None,
|
options.get(CONF_PROMPT),
|
||||||
user_input.extra_system_prompt,
|
user_input.extra_system_prompt,
|
||||||
)
|
)
|
||||||
except conversation.ConverseError as err:
|
except conversation.ConverseError as err:
|
||||||
return err.as_conversation_result()
|
return err.as_conversation_result()
|
||||||
|
|
||||||
|
tools: list[ChatCompletionToolParam] | None = None
|
||||||
|
if chat_log.llm_api:
|
||||||
|
tools = [
|
||||||
|
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
||||||
|
for tool in chat_log.llm_api.tools
|
||||||
|
]
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
m
|
m
|
||||||
for content in chat_log.content
|
for content in chat_log.content
|
||||||
@ -107,10 +192,12 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
|||||||
|
|
||||||
client = self.entry.runtime_data
|
client = self.entry.runtime_data
|
||||||
|
|
||||||
|
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||||
try:
|
try:
|
||||||
result = await client.chat.completions.create(
|
result = await client.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
tools=tools or NOT_GIVEN,
|
||||||
user=chat_log.conversation_id,
|
user=chat_log.conversation_id,
|
||||||
extra_headers={
|
extra_headers={
|
||||||
"X-Title": "Home Assistant",
|
"X-Title": "Home Assistant",
|
||||||
@ -123,11 +210,16 @@ class OpenRouterConversationEntity(conversation.ConversationEntity):
|
|||||||
|
|
||||||
result_message = result.choices[0].message
|
result_message = result.choices[0].message
|
||||||
|
|
||||||
chat_log.async_add_assistant_content_without_tools(
|
messages.extend(
|
||||||
conversation.AssistantContent(
|
[
|
||||||
agent_id=user_input.agent_id,
|
msg
|
||||||
content=result_message.content,
|
async for content in chat_log.async_add_delta_content_stream(
|
||||||
|
user_input.agent_id, _transform_response(result_message)
|
||||||
)
|
)
|
||||||
|
if (msg := _convert_content_to_chat_message(content))
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
if not chat_log.unresponded_tool_results:
|
||||||
|
break
|
||||||
|
|
||||||
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
return conversation.async_get_result_from_chat_log(user_input, chat_log)
|
||||||
|
@ -24,7 +24,13 @@
|
|||||||
"user": {
|
"user": {
|
||||||
"description": "Configure the new conversation agent",
|
"description": "Configure the new conversation agent",
|
||||||
"data": {
|
"data": {
|
||||||
"model": "Model"
|
"model": "Model",
|
||||||
|
"prompt": "Instructions",
|
||||||
|
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
|
||||||
|
},
|
||||||
|
"data_description": {
|
||||||
|
"model": "The model to use for the conversation agent",
|
||||||
|
"prompt": "Instruct how the LLM should respond. This can be a template."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from collections.abc import AsyncGenerator, Generator
|
from collections.abc import AsyncGenerator, Generator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
from openai.types import CompletionUsage
|
from openai.types import CompletionUsage
|
||||||
@ -9,10 +10,11 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
|||||||
from openai.types.chat.chat_completion import Choice
|
from openai.types.chat.chat_completion import Choice
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.open_router.const import DOMAIN
|
from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
|
||||||
from homeassistant.config_entries import ConfigSubentryData
|
from homeassistant.config_entries import ConfigSubentryData
|
||||||
from homeassistant.const import CONF_API_KEY, CONF_MODEL
|
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import llm
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
@ -29,7 +31,27 @@ def mock_setup_entry() -> Generator[AsyncMock]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
def enable_assist() -> bool:
|
||||||
|
"""Mock conversation subentry data."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def conversation_subentry_data(enable_assist: bool) -> dict[str, Any]:
|
||||||
|
"""Mock conversation subentry data."""
|
||||||
|
res: dict[str, Any] = {
|
||||||
|
CONF_MODEL: "gpt-3.5-turbo",
|
||||||
|
CONF_PROMPT: "You are a helpful assistant.",
|
||||||
|
}
|
||||||
|
if enable_assist:
|
||||||
|
res[CONF_LLM_HASS_API] = [llm.LLM_API_ASSIST]
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_config_entry(
|
||||||
|
hass: HomeAssistant, conversation_subentry_data: dict[str, Any]
|
||||||
|
) -> MockConfigEntry:
|
||||||
"""Mock a config entry."""
|
"""Mock a config entry."""
|
||||||
return MockConfigEntry(
|
return MockConfigEntry(
|
||||||
title="OpenRouter",
|
title="OpenRouter",
|
||||||
@ -39,7 +61,7 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
|||||||
},
|
},
|
||||||
subentries_data=[
|
subentries_data=[
|
||||||
ConfigSubentryData(
|
ConfigSubentryData(
|
||||||
data={CONF_MODEL: "gpt-3.5-turbo"},
|
data=conversation_subentry_data,
|
||||||
subentry_id="ABCDEF",
|
subentry_id="ABCDEF",
|
||||||
subentry_type="conversation",
|
subentry_type="conversation",
|
||||||
title="GPT-3.5 Turbo",
|
title="GPT-3.5 Turbo",
|
||||||
|
@ -1,4 +1,108 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
|
# name: test_all_entities[assist][conversation.gpt_3_5_turbo-entry]
|
||||||
|
EntityRegistryEntrySnapshot({
|
||||||
|
'aliases': set({
|
||||||
|
}),
|
||||||
|
'area_id': None,
|
||||||
|
'capabilities': None,
|
||||||
|
'config_entry_id': <ANY>,
|
||||||
|
'config_subentry_id': <ANY>,
|
||||||
|
'device_class': None,
|
||||||
|
'device_id': <ANY>,
|
||||||
|
'disabled_by': None,
|
||||||
|
'domain': 'conversation',
|
||||||
|
'entity_category': None,
|
||||||
|
'entity_id': 'conversation.gpt_3_5_turbo',
|
||||||
|
'has_entity_name': True,
|
||||||
|
'hidden_by': None,
|
||||||
|
'icon': None,
|
||||||
|
'id': <ANY>,
|
||||||
|
'labels': set({
|
||||||
|
}),
|
||||||
|
'name': None,
|
||||||
|
'options': dict({
|
||||||
|
'conversation': dict({
|
||||||
|
'should_expose': False,
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'original_device_class': None,
|
||||||
|
'original_icon': None,
|
||||||
|
'original_name': None,
|
||||||
|
'platform': 'open_router',
|
||||||
|
'previous_unique_id': None,
|
||||||
|
'suggested_object_id': None,
|
||||||
|
'supported_features': <ConversationEntityFeature: 1>,
|
||||||
|
'translation_key': None,
|
||||||
|
'unique_id': 'ABCDEF',
|
||||||
|
'unit_of_measurement': None,
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_all_entities[assist][conversation.gpt_3_5_turbo-state]
|
||||||
|
StateSnapshot({
|
||||||
|
'attributes': ReadOnlyDict({
|
||||||
|
'friendly_name': 'GPT-3.5 Turbo',
|
||||||
|
'supported_features': <ConversationEntityFeature: 1>,
|
||||||
|
}),
|
||||||
|
'context': <ANY>,
|
||||||
|
'entity_id': 'conversation.gpt_3_5_turbo',
|
||||||
|
'last_changed': <ANY>,
|
||||||
|
'last_reported': <ANY>,
|
||||||
|
'last_updated': <ANY>,
|
||||||
|
'state': 'unknown',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_all_entities[no_assist][conversation.gpt_3_5_turbo-entry]
|
||||||
|
EntityRegistryEntrySnapshot({
|
||||||
|
'aliases': set({
|
||||||
|
}),
|
||||||
|
'area_id': None,
|
||||||
|
'capabilities': None,
|
||||||
|
'config_entry_id': <ANY>,
|
||||||
|
'config_subentry_id': <ANY>,
|
||||||
|
'device_class': None,
|
||||||
|
'device_id': <ANY>,
|
||||||
|
'disabled_by': None,
|
||||||
|
'domain': 'conversation',
|
||||||
|
'entity_category': None,
|
||||||
|
'entity_id': 'conversation.gpt_3_5_turbo',
|
||||||
|
'has_entity_name': True,
|
||||||
|
'hidden_by': None,
|
||||||
|
'icon': None,
|
||||||
|
'id': <ANY>,
|
||||||
|
'labels': set({
|
||||||
|
}),
|
||||||
|
'name': None,
|
||||||
|
'options': dict({
|
||||||
|
'conversation': dict({
|
||||||
|
'should_expose': False,
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'original_device_class': None,
|
||||||
|
'original_icon': None,
|
||||||
|
'original_name': None,
|
||||||
|
'platform': 'open_router',
|
||||||
|
'previous_unique_id': None,
|
||||||
|
'suggested_object_id': None,
|
||||||
|
'supported_features': 0,
|
||||||
|
'translation_key': None,
|
||||||
|
'unique_id': 'ABCDEF',
|
||||||
|
'unit_of_measurement': None,
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_all_entities[no_assist][conversation.gpt_3_5_turbo-state]
|
||||||
|
StateSnapshot({
|
||||||
|
'attributes': ReadOnlyDict({
|
||||||
|
'friendly_name': 'GPT-3.5 Turbo',
|
||||||
|
'supported_features': <ConversationEntityFeature: 0>,
|
||||||
|
}),
|
||||||
|
'context': <ANY>,
|
||||||
|
'entity_id': 'conversation.gpt_3_5_turbo',
|
||||||
|
'last_changed': <ANY>,
|
||||||
|
'last_reported': <ANY>,
|
||||||
|
'last_updated': <ANY>,
|
||||||
|
'state': 'unknown',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
# name: test_default_prompt
|
# name: test_default_prompt
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
@ -14,3 +118,39 @@
|
|||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_function_call[True]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'attachments': None,
|
||||||
|
'content': 'Please call the test function',
|
||||||
|
'role': 'user',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'agent_id': 'conversation.gpt_3_5_turbo',
|
||||||
|
'content': None,
|
||||||
|
'role': 'assistant',
|
||||||
|
'tool_calls': list([
|
||||||
|
dict({
|
||||||
|
'id': 'call_call_1',
|
||||||
|
'tool_args': dict({
|
||||||
|
'param1': 'call1',
|
||||||
|
}),
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'agent_id': 'conversation.gpt_3_5_turbo',
|
||||||
|
'role': 'tool_result',
|
||||||
|
'tool_call_id': 'call_call_1',
|
||||||
|
'tool_name': 'test_tool',
|
||||||
|
'tool_result': 'value1',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'agent_id': 'conversation.gpt_3_5_turbo',
|
||||||
|
'content': 'I have successfully called the function',
|
||||||
|
'role': 'assistant',
|
||||||
|
'tool_calls': None,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
@ -5,9 +5,9 @@ from unittest.mock import AsyncMock
|
|||||||
import pytest
|
import pytest
|
||||||
from python_open_router import OpenRouterError
|
from python_open_router import OpenRouterError
|
||||||
|
|
||||||
from homeassistant.components.open_router.const import DOMAIN
|
from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
|
||||||
from homeassistant.config_entries import SOURCE_USER, ConfigSubentry
|
from homeassistant.config_entries import SOURCE_USER
|
||||||
from homeassistant.const import CONF_API_KEY, CONF_MODEL
|
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.data_entry_flow import FlowResultType
|
from homeassistant.data_entry_flow import FlowResultType
|
||||||
|
|
||||||
@ -129,18 +129,56 @@ async def test_create_conversation_agent(
|
|||||||
|
|
||||||
result = await hass.config_entries.subentries.async_configure(
|
result = await hass.config_entries.subentries.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
{CONF_MODEL: "gpt-3.5-turbo"},
|
{
|
||||||
|
CONF_MODEL: "gpt-3.5-turbo",
|
||||||
|
CONF_PROMPT: "you are an assistant",
|
||||||
|
CONF_LLM_HASS_API: ["assist"],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||||
subentry_id = list(mock_config_entry.subentries)[0]
|
assert result["data"] == {
|
||||||
assert (
|
CONF_MODEL: "gpt-3.5-turbo",
|
||||||
ConfigSubentry(
|
CONF_PROMPT: "you are an assistant",
|
||||||
data={CONF_MODEL: "gpt-3.5-turbo"},
|
CONF_LLM_HASS_API: ["assist"],
|
||||||
subentry_id=subentry_id,
|
}
|
||||||
subentry_type="conversation",
|
|
||||||
title="GPT-3.5 Turbo",
|
|
||||||
unique_id=None,
|
async def test_create_conversation_agent_no_control(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_open_router_client: AsyncMock,
|
||||||
|
mock_openai_client: AsyncMock,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating a conversation agent without control over the LLM API."""
|
||||||
|
|
||||||
|
mock_config_entry.add_to_hass(hass)
|
||||||
|
|
||||||
|
await setup_integration(hass, mock_config_entry)
|
||||||
|
|
||||||
|
result = await hass.config_entries.subentries.async_init(
|
||||||
|
(mock_config_entry.entry_id, "conversation"),
|
||||||
|
context={"source": SOURCE_USER},
|
||||||
)
|
)
|
||||||
in mock_config_entry.subentries.values()
|
assert result["type"] is FlowResultType.FORM
|
||||||
|
assert not result["errors"]
|
||||||
|
assert result["step_id"] == "user"
|
||||||
|
|
||||||
|
assert result["data_schema"].schema["model"].config["options"] == [
|
||||||
|
{"value": "gpt-3.5-turbo", "label": "GPT-3.5 Turbo"},
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await hass.config_entries.subentries.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{
|
||||||
|
CONF_MODEL: "gpt-3.5-turbo",
|
||||||
|
CONF_PROMPT: "you are an assistant",
|
||||||
|
CONF_LLM_HASS_API: [],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||||
|
assert result["data"] == {
|
||||||
|
CONF_MODEL: "gpt-3.5-turbo",
|
||||||
|
CONF_PROMPT: "you are an assistant",
|
||||||
|
}
|
||||||
|
@ -3,16 +3,24 @@
|
|||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
from openai.types import CompletionUsage
|
||||||
|
from openai.types.chat import (
|
||||||
|
ChatCompletion,
|
||||||
|
ChatCompletionMessage,
|
||||||
|
ChatCompletionMessageToolCall,
|
||||||
|
)
|
||||||
|
from openai.types.chat.chat_completion import Choice
|
||||||
|
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
|
from homeassistant.helpers import entity_registry as er, intent
|
||||||
|
|
||||||
from . import setup_integration
|
from . import setup_integration
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry, snapshot_platform
|
||||||
from tests.components.conversation import MockChatLog, mock_chat_log # noqa: F401
|
from tests.components.conversation import MockChatLog, mock_chat_log # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
@ -23,11 +31,23 @@ def freeze_the_time():
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("enable_assist", [True, False], ids=["assist", "no_assist"])
|
||||||
|
async def test_all_entities(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
mock_openai_client: AsyncMock,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test all entities."""
|
||||||
|
await setup_integration(hass, mock_config_entry)
|
||||||
|
|
||||||
|
await snapshot_platform(hass, entity_registry, snapshot, mock_config_entry.entry_id)
|
||||||
|
|
||||||
|
|
||||||
async def test_default_prompt(
|
async def test_default_prompt(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
area_registry: ar.AreaRegistry,
|
|
||||||
device_registry: dr.DeviceRegistry,
|
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
mock_openai_client: AsyncMock,
|
mock_openai_client: AsyncMock,
|
||||||
mock_chat_log: MockChatLog, # noqa: F811
|
mock_chat_log: MockChatLog, # noqa: F811
|
||||||
@ -50,3 +70,95 @@ async def test_default_prompt(
|
|||||||
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
||||||
"X-Title": "Home Assistant",
|
"X-Title": "Home Assistant",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("enable_assist", [True])
|
||||||
|
async def test_function_call(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_chat_log: MockChatLog, # noqa: F811
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
mock_openai_client: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test function call from the assistant."""
|
||||||
|
await setup_integration(hass, mock_config_entry)
|
||||||
|
|
||||||
|
mock_chat_log.mock_tool_results(
|
||||||
|
{
|
||||||
|
"call_call_1": "value1",
|
||||||
|
"call_call_2": "value2",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def completion_result(*args, messages, **kwargs):
|
||||||
|
for message in messages:
|
||||||
|
role = message["role"] if isinstance(message, dict) else message.role
|
||||||
|
if role == "tool":
|
||||||
|
return ChatCompletion(
|
||||||
|
id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
finish_reason="stop",
|
||||||
|
index=0,
|
||||||
|
message=ChatCompletionMessage(
|
||||||
|
content="I have successfully called the function",
|
||||||
|
role="assistant",
|
||||||
|
function_call=None,
|
||||||
|
tool_calls=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=1700000000,
|
||||||
|
model="gpt-4-1106-preview",
|
||||||
|
object="chat.completion",
|
||||||
|
system_fingerprint=None,
|
||||||
|
usage=CompletionUsage(
|
||||||
|
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatCompletion(
|
||||||
|
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
||||||
|
choices=[
|
||||||
|
Choice(
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
index=0,
|
||||||
|
message=ChatCompletionMessage(
|
||||||
|
content=None,
|
||||||
|
role="assistant",
|
||||||
|
function_call=None,
|
||||||
|
tool_calls=[
|
||||||
|
ChatCompletionMessageToolCall(
|
||||||
|
id="call_call_1",
|
||||||
|
function=Function(
|
||||||
|
arguments='{"param1":"call1"}',
|
||||||
|
name="test_tool",
|
||||||
|
),
|
||||||
|
type="function",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
created=1700000000,
|
||||||
|
model="gpt-4-1106-preview",
|
||||||
|
object="chat.completion",
|
||||||
|
system_fingerprint=None,
|
||||||
|
usage=CompletionUsage(
|
||||||
|
completion_tokens=9, prompt_tokens=8, total_tokens=17
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_openai_client.chat.completions.create = completion_result
|
||||||
|
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"Please call the test function",
|
||||||
|
mock_chat_log.conversation_id,
|
||||||
|
Context(),
|
||||||
|
agent_id="conversation.gpt_3_5_turbo",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
# Don't test the prompt, as it's not deterministic
|
||||||
|
assert mock_chat_log.content[1:] == snapshot
|
||||||
|
Loading…
x
Reference in New Issue
Block a user