LLM Tools support for OpenAI integration (#117645)

* initial commit

* Add tests

* Move tests to the correct file

* Fix exception type

* Undo change to default prompt

* Add intent dependency

* Move format_tool out of the class

* Fix tests

* coverage

* Adjust to new API

* Update strings

* Update tests

* Remove unrelated change

* Test referencing non-existing API

* Add test to verify no exception on tool conversion for Assist tools

* Bump voluptuous-openapi==0.0.4

* Add device_id to tool input

* Fix tests

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Denis Shulyaka 2024-05-22 05:45:04 +03:00 committed by GitHub
parent 09213d8933
commit 2f0215b034
11 changed files with 665 additions and 79 deletions

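At a high level, the new tool-calling loop in conversation.py builds up an OpenAI chat message sequence. A minimal sketch of that sequence, reconstructed from the mocked test data further below (the IDs and values come from the tests, not from real traffic):

from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (
    ChatCompletionMessageToolCall,
    Function,
)

messages = [
    {"role": "system", "content": "<rendered prompt>"},  # placeholder prompt
    {"role": "user", "content": "Please call the test function"},
    # Assistant turn returned by the model, requesting a tool call
    ChatCompletionMessage(
        role="assistant",
        content=None,
        tool_calls=[
            ChatCompletionMessageToolCall(
                id="call_AbCdEfGhIjKlMnOpQrStUvWx",
                type="function",
                function=Function(
                    name="test_tool", arguments='{"param1":"test_value"}'
                ),
            )
        ],
    ),
    # Tool result appended by the integration and sent back to the model
    {
        "role": "tool",
        "tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
        "name": "test_tool",
        "content": '"Test response"',
    },
    # Final assistant turn, spoken to the user
    ChatCompletionMessage(
        role="assistant", content="I have successfully called the function"
    ),
]
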
homeassistant/components/google_generative_ai_conversation/manifest.json

@@ -8,5 +8,5 @@
   "documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation",
   "integration_type": "service",
   "iot_class": "cloud_polling",
-  "requirements": ["google-generativeai==0.5.4", "voluptuous-openapi==0.0.3"]
+  "requirements": ["google-generativeai==0.5.4", "voluptuous-openapi==0.0.4"]
 }

homeassistant/components/openai_conversation/config_flow.py

@@ -3,7 +3,6 @@
 from __future__ import annotations

 import logging
-import types
 from types import MappingProxyType
 from typing import Any
@@ -16,11 +15,15 @@ from homeassistant.config_entries import (
     ConfigFlowResult,
     OptionsFlow,
 )
-from homeassistant.const import CONF_API_KEY
+from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant
+from homeassistant.helpers import llm
 from homeassistant.helpers.selector import (
     NumberSelector,
     NumberSelectorConfig,
+    SelectOptionDict,
+    SelectSelector,
+    SelectSelectorConfig,
     TemplateSelector,
 )
@@ -46,16 +49,6 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
     }
 )

-DEFAULT_OPTIONS = types.MappingProxyType(
-    {
-        CONF_PROMPT: DEFAULT_PROMPT,
-        CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
-        CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
-        CONF_TOP_P: DEFAULT_TOP_P,
-        CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
-    }
-)
-

 async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
     """Validate the user input allows us to connect.
@@ -92,7 +85,11 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
             _LOGGER.exception("Unexpected exception")
             errors["base"] = "unknown"
         else:
-            return self.async_create_entry(title="OpenAI Conversation", data=user_input)
+            return self.async_create_entry(
+                title="OpenAI Conversation",
+                data=user_input,
+                options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
+            )

         return self.async_show_form(
             step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
@@ -118,45 +115,67 @@ class OpenAIOptionsFlow(OptionsFlow):
     ) -> ConfigFlowResult:
         """Manage the options."""
         if user_input is not None:
-            return self.async_create_entry(title="OpenAI Conversation", data=user_input)
-        schema = openai_config_option_schema(self.config_entry.options)
+            if user_input[CONF_LLM_HASS_API] == "none":
+                user_input.pop(CONF_LLM_HASS_API)
+            return self.async_create_entry(title="", data=user_input)
+        schema = openai_config_option_schema(self.hass, self.config_entry.options)
         return self.async_show_form(
             step_id="init",
             data_schema=vol.Schema(schema),
         )


-def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
+def openai_config_option_schema(
+    hass: HomeAssistant,
+    options: MappingProxyType[str, Any],
+) -> dict:
     """Return a schema for OpenAI completion options."""
-    if not options:
-        options = DEFAULT_OPTIONS
+    apis: list[SelectOptionDict] = [
+        SelectOptionDict(
+            label="No control",
+            value="none",
+        )
+    ]
+    apis.extend(
+        SelectOptionDict(
+            label=api.name,
+            value=api.id,
+        )
+        for api in llm.async_get_apis(hass)
+    )
     return {
-        vol.Optional(
-            CONF_PROMPT,
-            description={"suggested_value": options[CONF_PROMPT]},
-            default=DEFAULT_PROMPT,
-        ): TemplateSelector(),
         vol.Optional(
             CONF_CHAT_MODEL,
             description={
                 # New key in HA 2023.4
-                "suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
+                "suggested_value": options.get(CONF_CHAT_MODEL)
             },
             default=DEFAULT_CHAT_MODEL,
         ): str,
+        vol.Optional(
+            CONF_LLM_HASS_API,
+            description={"suggested_value": options.get(CONF_LLM_HASS_API)},
+            default="none",
+        ): SelectSelector(SelectSelectorConfig(options=apis)),
+        vol.Optional(
+            CONF_PROMPT,
+            description={"suggested_value": options.get(CONF_PROMPT)},
+            default=DEFAULT_PROMPT,
+        ): TemplateSelector(),
         vol.Optional(
             CONF_MAX_TOKENS,
-            description={"suggested_value": options[CONF_MAX_TOKENS]},
+            description={"suggested_value": options.get(CONF_MAX_TOKENS)},
             default=DEFAULT_MAX_TOKENS,
         ): int,
         vol.Optional(
             CONF_TOP_P,
-            description={"suggested_value": options[CONF_TOP_P]},
+            description={"suggested_value": options.get(CONF_TOP_P)},
             default=DEFAULT_TOP_P,
         ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
         vol.Optional(
             CONF_TEMPERATURE,
-            description={"suggested_value": options[CONF_TEMPERATURE]},
+            description={"suggested_value": options.get(CONF_TEMPERATURE)},
             default=DEFAULT_TEMPERATURE,
         ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
     }

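For a sense of what the new selector offers: a minimal sketch of the option list built above, assuming only the built-in Assist API is registered (its "Assist"/"assist" name and id are an assumption here, not part of this diff):

from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.helpers.selector import SelectOptionDict


def build_api_options(hass: HomeAssistant) -> list[SelectOptionDict]:
    """Mirror the option-building logic in openai_config_option_schema."""
    # "No control" maps to the "none" sentinel that the options flow pops
    # from user_input before saving the entry.
    apis = [SelectOptionDict(label="No control", value="none")]
    apis.extend(
        SelectOptionDict(label=api.name, value=api.id)
        for api in llm.async_get_apis(hass)
    )
    return apis


# With only the Assist API registered this yields, roughly:
# [{'label': 'No control', 'value': 'none'}, {'label': 'Assist', 'value': 'assist'}]
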
homeassistant/components/openai_conversation/const.py

@@ -21,10 +21,6 @@ An overview of the areas and the devices in this smart home:
 {%- endif %}
 {%- endfor %}
 {%- endfor %}
-
-Answer the user's questions about the world truthfully.
-
-If the user wants to control a device, reject the request and suggest using the Home Assistant app.
 """

 CONF_CHAT_MODEL = "chat_model"
 DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"

homeassistant/components/openai_conversation/conversation.py

@@ -1,15 +1,18 @@
 """Conversation support for OpenAI."""

-from typing import Literal
+import json
+from typing import Any, Literal

 import openai
+import voluptuous as vol
+from voluptuous_openapi import convert

 from homeassistant.components import assist_pipeline, conversation
 from homeassistant.config_entries import ConfigEntry
-from homeassistant.const import MATCH_ALL
+from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import TemplateError
-from homeassistant.helpers import intent, template
+from homeassistant.exceptions import HomeAssistantError, TemplateError
+from homeassistant.helpers import intent, llm, template
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
 from homeassistant.util import ulid
@@ -28,6 +31,9 @@ from .const import (
     LOGGER,
 )

+# Max number of back and forth with the LLM to generate a response
+MAX_TOOL_ITERATIONS = 10
+

 async def async_setup_entry(
     hass: HomeAssistant,
@@ -39,6 +45,15 @@ async def async_setup_entry(
     async_add_entities([agent])


+def _format_tool(tool: llm.Tool) -> dict[str, Any]:
+    """Format tool specification."""
+    tool_spec = {"name": tool.name}
+    if tool.description:
+        tool_spec["description"] = tool.description
+    tool_spec["parameters"] = convert(tool.parameters)
+    return {"type": "function", "function": tool_spec}
+
+
 class OpenAIConversationEntity(
     conversation.ConversationEntity, conversation.AbstractConversationAgent
 ):
@@ -75,6 +90,26 @@ class OpenAIConversationEntity(
         self, user_input: conversation.ConversationInput
     ) -> conversation.ConversationResult:
         """Process a sentence."""
+        intent_response = intent.IntentResponse(language=user_input.language)
+        llm_api: llm.API | None = None
+        tools: list[dict[str, Any]] | None = None
+
+        if self.entry.options.get(CONF_LLM_HASS_API):
+            try:
+                llm_api = llm.async_get_api(
+                    self.hass, self.entry.options[CONF_LLM_HASS_API]
+                )
+            except HomeAssistantError as err:
+                LOGGER.error("Error getting LLM API: %s", err)
+                intent_response.async_set_error(
+                    intent.IntentResponseErrorCode.UNKNOWN,
+                    f"Error preparing LLM API: {err}",
+                )
+                return conversation.ConversationResult(
+                    response=intent_response, conversation_id=user_input.conversation_id
+                )
+            tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
+
         raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
         model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
         max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -87,7 +122,10 @@ class OpenAIConversationEntity(
         else:
             conversation_id = ulid.ulid_now()

         try:
-            prompt = self._async_generate_prompt(raw_prompt)
+            prompt = self._async_generate_prompt(
+                raw_prompt,
+                llm_api,
+            )
         except TemplateError as err:
             LOGGER.error("Error rendering prompt: %s", err)
             intent_response = intent.IntentResponse(language=user_input.language)
@@ -106,38 +144,88 @@ class OpenAIConversationEntity(
         client = self.hass.data[DOMAIN][self.entry.entry_id]

-        try:
-            result = await client.chat.completions.create(
-                model=model,
-                messages=messages,
-                max_tokens=max_tokens,
-                top_p=top_p,
-                temperature=temperature,
-                user=conversation_id,
-            )
-        except openai.OpenAIError as err:
-            intent_response = intent.IntentResponse(language=user_input.language)
-            intent_response.async_set_error(
-                intent.IntentResponseErrorCode.UNKNOWN,
-                f"Sorry, I had a problem talking to OpenAI: {err}",
-            )
-            return conversation.ConversationResult(
-                response=intent_response, conversation_id=conversation_id
-            )
+        # To prevent infinite loops, we limit the number of iterations
+        for _iteration in range(MAX_TOOL_ITERATIONS):
+            try:
+                result = await client.chat.completions.create(
+                    model=model,
+                    messages=messages,
+                    tools=tools,
+                    max_tokens=max_tokens,
+                    top_p=top_p,
+                    temperature=temperature,
+                    user=conversation_id,
+                )
+            except openai.OpenAIError as err:
+                intent_response = intent.IntentResponse(language=user_input.language)
+                intent_response.async_set_error(
+                    intent.IntentResponseErrorCode.UNKNOWN,
+                    f"Sorry, I had a problem talking to OpenAI: {err}",
+                )
+                return conversation.ConversationResult(
+                    response=intent_response, conversation_id=conversation_id
+                )

-        LOGGER.debug("Response %s", result)
-        response = result.choices[0].message.model_dump(include={"role", "content"})
-        messages.append(response)
+            LOGGER.debug("Response %s", result)
+            response = result.choices[0].message
+            messages.append(response)
+            tool_calls = response.tool_calls
+
+            if not tool_calls or not llm_api:
+                break
+
+            for tool_call in tool_calls:
+                tool_input = llm.ToolInput(
+                    tool_name=tool_call.function.name,
+                    tool_args=json.loads(tool_call.function.arguments),
+                    platform=DOMAIN,
+                    context=user_input.context,
+                    user_prompt=user_input.text,
+                    language=user_input.language,
+                    assistant=conversation.DOMAIN,
+                    device_id=user_input.device_id,
+                )
+                LOGGER.debug(
+                    "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
+                )
+
+                try:
+                    tool_response = await llm_api.async_call_tool(tool_input)
+                except (HomeAssistantError, vol.Invalid) as e:
+                    tool_response = {"error": type(e).__name__}
+                    if str(e):
+                        tool_response["error_text"] = str(e)
+
+                LOGGER.debug("Tool response: %s", tool_response)
+                messages.append(
+                    {
+                        "role": "tool",
+                        "tool_call_id": tool_call.id,
+                        "name": tool_call.function.name,
+                        "content": json.dumps(tool_response),
+                    }
+                )

         self.history[conversation_id] = messages

         intent_response = intent.IntentResponse(language=user_input.language)
-        intent_response.async_set_speech(response["content"])
+        intent_response.async_set_speech(response.content)
         return conversation.ConversationResult(
             response=intent_response, conversation_id=conversation_id
         )

-    def _async_generate_prompt(self, raw_prompt: str) -> str:
+    def _async_generate_prompt(
+        self,
+        raw_prompt: str,
+        llm_api: llm.API | None,
+    ) -> str:
         """Generate a prompt for the user."""
+        raw_prompt += "\n"
+        if llm_api:
+            raw_prompt += llm_api.prompt_template
+        else:
+            raw_prompt += llm.PROMPT_NO_API_CONFIGURED
         return template.Template(raw_prompt, self.hass).async_render(
             {
                 "ha_name": self.hass.config.location_name,

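For reference (not part of the diff): given the mock tool used in the tests below, _format_tool would produce an OpenAI function-tool spec along these lines, assuming voluptuous_openapi.convert maps the voluptuous schema to JSON-Schema-style parameters:

import voluptuous as vol
from voluptuous_openapi import convert

# Schema taken from the mocked test_tool in the tests below
parameters = vol.Schema(
    {vol.Optional("param1", description="Test parameters"): str}
)

tool_spec = {
    "type": "function",
    "function": {
        "name": "test_tool",
        "description": "Test function",
        # convert() is expected to emit something like:
        # {'type': 'object', 'properties': {'param1': {'type': 'string',
        #  'description': 'Test parameters'}}, 'required': []}
        "parameters": convert(parameters),
    },
}
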
homeassistant/components/openai_conversation/manifest.json

@@ -1,12 +1,12 @@
 {
   "domain": "openai_conversation",
   "name": "OpenAI Conversation",
-  "after_dependencies": ["assist_pipeline"],
+  "after_dependencies": ["assist_pipeline", "intent"],
   "codeowners": ["@balloob"],
   "config_flow": true,
   "dependencies": ["conversation"],
   "documentation": "https://www.home-assistant.io/integrations/openai_conversation",
   "integration_type": "service",
   "iot_class": "cloud_polling",
-  "requirements": ["openai==1.3.8"]
+  "requirements": ["openai==1.3.8", "voluptuous-openapi==0.0.4"]
 }

homeassistant/components/openai_conversation/strings.json

@@ -18,10 +18,11 @@
     "init": {
       "data": {
         "prompt": "Prompt Template",
-        "model": "Completion Model",
+        "chat_model": "[%key:common::generic::model%]",
         "max_tokens": "Maximum tokens to return in response",
         "temperature": "Temperature",
-        "top_p": "Top P"
+        "top_p": "Top P",
+        "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
       }
     }
   }

requirements_all.txt

@@ -2826,7 +2826,8 @@ voip-utils==0.1.0
 volkszaehler==0.4.0

 # homeassistant.components.google_generative_ai_conversation
-voluptuous-openapi==0.0.3
+# homeassistant.components.openai_conversation
+voluptuous-openapi==0.0.4

 # homeassistant.components.volvooncall
 volvooncall==0.10.3

requirements_test_all.txt

@@ -2191,7 +2191,8 @@ vilfo-api-client==0.5.0
 voip-utils==0.1.0

 # homeassistant.components.google_generative_ai_conversation
-voluptuous-openapi==0.0.3
+# homeassistant.components.openai_conversation
+voluptuous-openapi==0.0.4

 # homeassistant.components.volvooncall
 volvooncall==0.10.3

tests/components/openai_conversation/conftest.py

@@ -4,7 +4,9 @@ from unittest.mock import patch

 import pytest

+from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant
+from homeassistant.helpers import llm
 from homeassistant.setup import async_setup_component

 from tests.common import MockConfigEntry
@@ -24,6 +26,15 @@ def mock_config_entry(hass):
     return entry


+@pytest.fixture
+def mock_config_entry_with_assist(hass, mock_config_entry):
+    """Mock a config entry with assist."""
+    hass.config_entries.async_update_entry(
+        mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
+    )
+    return mock_config_entry
+
+
 @pytest.fixture
 async def mock_init_component(hass, mock_config_entry):
     """Initialize integration."""

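Usage note: a hypothetical test (not in this diff, using the imports above) opts into the Assist API simply by requesting the new fixture, which updates the entry options before the integration is initialized:

async def test_with_assist_api(
    hass: HomeAssistant,
    mock_config_entry_with_assist: MockConfigEntry,
    mock_init_component,
) -> None:
    """Sketch: the entry now selects the Assist LLM API."""
    assert (
        mock_config_entry_with_assist.options[CONF_LLM_HASS_API]
        == llm.LLM_API_ASSIST
    )
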
tests/components/openai_conversation/snapshots/test_conversation.ambr

@@ -16,9 +16,7 @@
       - Test Device 4
       - 1 (3)

-    Answer the user's questions about the world truthfully.
-
-    If the user wants to control a device, reject the request and suggest using the Home Assistant app.
+    If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
       ''',
       'role': 'system',
     }),
@@ -26,10 +24,119 @@
       'content': 'hello',
       'role': 'user',
     }),
-    dict({
-      'content': 'Hello, how can I help you?',
-      'role': 'assistant',
-    }),
+    ChatCompletionMessage(content='Hello, how can I help you?', role='assistant', function_call=None, tool_calls=None),
   ])
 # ---
+# name: test_default_prompt[config_entry_options0-None]
+  list([
+    dict({
+      'content': '''
+        This smart home is controlled by Home Assistant.
+
+        An overview of the areas and the devices in this smart home:
+        Test Area:
+        - Test Device (Test Model)
+        Test Area 2:
+        - Test Device 2
+        - Test Device 3 (Test Model 3A)
+        - Test Device 4
+        - 1 (3)
+
+        Call the intent tools to control the system. Just pass the name to the intent.
+      ''',
+      'role': 'system',
+    }),
+    dict({
+      'content': 'hello',
+      'role': 'user',
+    }),
+    ChatCompletionMessage(content='Hello, how can I help you?', role='assistant', function_call=None, tool_calls=None),
+  ])
+# ---
+# name: test_default_prompt[config_entry_options0-conversation.openai]
+  list([
+    dict({
+      'content': '''
+        This smart home is controlled by Home Assistant.
+
+        An overview of the areas and the devices in this smart home:
+        Test Area:
+        - Test Device (Test Model)
+        Test Area 2:
+        - Test Device 2
+        - Test Device 3 (Test Model 3A)
+        - Test Device 4
+        - 1 (3)
+
+        Call the intent tools to control the system. Just pass the name to the intent.
+      ''',
+      'role': 'system',
+    }),
+    dict({
+      'content': 'hello',
+      'role': 'user',
+    }),
+    ChatCompletionMessage(content='Hello, how can I help you?', role='assistant', function_call=None, tool_calls=None),
+  ])
+# ---
+# name: test_default_prompt[config_entry_options1-None]
+  list([
+    dict({
+      'content': '''
+        This smart home is controlled by Home Assistant.
+
+        An overview of the areas and the devices in this smart home:
+        Test Area:
+        - Test Device (Test Model)
+        Test Area 2:
+        - Test Device 2
+        - Test Device 3 (Test Model 3A)
+        - Test Device 4
+        - 1 (3)
+
+        Call the intent tools to control the system. Just pass the name to the intent.
+      ''',
+      'role': 'system',
+    }),
+    dict({
+      'content': 'hello',
+      'role': 'user',
+    }),
+    ChatCompletionMessage(content='Hello, how can I help you?', role='assistant', function_call=None, tool_calls=None),
+  ])
+# ---
+# name: test_default_prompt[config_entry_options1-conversation.openai]
+  list([
+    dict({
+      'content': '''
+        This smart home is controlled by Home Assistant.
+
+        An overview of the areas and the devices in this smart home:
+        Test Area:
+        - Test Device (Test Model)
+        Test Area 2:
+        - Test Device 2
+        - Test Device 3 (Test Model 3A)
+        - Test Device 4
+        - 1 (3)
+
+        Call the intent tools to control the system. Just pass the name to the intent.
+      ''',
+      'role': 'system',
+    }),
+    dict({
+      'content': 'hello',
+      'role': 'user',
+    }),
+    ChatCompletionMessage(content='Hello, how can I help you?', role='assistant', function_call=None, tool_calls=None),
+  ])
+# ---
 # name: test_default_prompt[conversation.openai]
@@ -49,9 +156,7 @@
       - Test Device 4
       - 1 (3)

-    Answer the user's questions about the world truthfully.
-
-    If the user wants to control a device, reject the request and suggest using the Home Assistant app.
+    If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
       ''',
       'role': 'system',
     }),
@@ -59,9 +164,39 @@
       'content': 'hello',
       'role': 'user',
     }),
-    dict({
-      'content': 'Hello, how can I help you?',
-      'role': 'assistant',
-    }),
+    ChatCompletionMessage(content='Hello, how can I help you?', role='assistant', function_call=None, tool_calls=None),
   ])
 # ---
+# name: test_unknown_hass_api
+  dict({
+    'conversation_id': None,
+    'response': IntentResponse(
+      card=dict({
+      }),
+      error_code=<IntentResponseErrorCode.UNKNOWN: 'unknown'>,
+      failed_results=list([
+      ]),
+      intent=None,
+      intent_targets=list([
+      ]),
+      language='en',
+      matched_states=list([
+      ]),
+      reprompt=dict({
+      }),
+      response_type=<IntentResponseType.ERROR: 'error'>,
+      speech=dict({
+        'plain': dict({
+          'extra_data': None,
+          'speech': 'Error preparing LLM API: API non-existing not found',
+        }),
+      }),
+      speech_slots=dict({
+      }),
+      success_results=list([
+      ]),
+      unmatched_states=list([
+      ]),
+    ),
+  })
+# ---

tests/components/openai_conversation/test_conversation.py

@@ -6,18 +6,34 @@ from httpx import Response
 from openai import RateLimitError
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
+from openai.types.chat.chat_completion_message_tool_call import (
+    ChatCompletionMessageToolCall,
+    Function,
+)
 from openai.types.completion_usage import CompletionUsage
 import pytest
 from syrupy.assertion import SnapshotAssertion
+import voluptuous as vol

 from homeassistant.components import conversation
+from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import Context, HomeAssistant
-from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import (
+    area_registry as ar,
+    device_registry as dr,
+    intent,
+    llm,
+)
+from homeassistant.setup import async_setup_component

 from tests.common import MockConfigEntry


 @pytest.mark.parametrize("agent_id", [None, "conversation.openai"])
+@pytest.mark.parametrize(
+    "config_entry_options", [{}, {CONF_LLM_HASS_API: llm.LLM_API_ASSIST}]
+)
 async def test_default_prompt(
     hass: HomeAssistant,
     mock_config_entry: MockConfigEntry,
@@ -26,6 +42,7 @@ async def test_default_prompt(
     device_registry: dr.DeviceRegistry,
     snapshot: SnapshotAssertion,
     agent_id: str,
+    config_entry_options: dict,
 ) -> None:
     """Test that the default prompt works."""
     entry = MockConfigEntry(title=None)
@@ -36,6 +53,14 @@ async def test_default_prompt(
     if agent_id is None:
         agent_id = mock_config_entry.entry_id

+    hass.config_entries.async_update_entry(
+        mock_config_entry,
+        options={
+            **mock_config_entry.options,
+            CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
+        },
+    )
+
     device_registry.async_get_or_create(
         config_entry_id=entry.entry_id,
         connections={("test", "1234")},
@@ -194,3 +219,312 @@ async def test_conversation_agent(
         mock_config_entry.entry_id
     )
     assert agent.supported_languages == "*"
+
+
+@patch(
+    "homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"
+)
+async def test_function_call(
+    mock_get_tools,
+    hass: HomeAssistant,
+    mock_config_entry_with_assist: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test function call from the assistant."""
+    agent_id = mock_config_entry_with_assist.entry_id
+    context = Context()
+
+    mock_tool = AsyncMock()
+    mock_tool.name = "test_tool"
+    mock_tool.description = "Test function"
+    mock_tool.parameters = vol.Schema(
+        {vol.Optional("param1", description="Test parameters"): str}
+    )
+    mock_tool.async_call.return_value = "Test response"
+
+    mock_get_tools.return_value = [mock_tool]
+
+    def completion_result(*args, messages, **kwargs):
+        for message in messages:
+            role = message["role"] if isinstance(message, dict) else message.role
+            if role == "tool":
+                return ChatCompletion(
+                    id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
+                    choices=[
+                        Choice(
+                            finish_reason="stop",
+                            index=0,
+                            message=ChatCompletionMessage(
+                                content="I have successfully called the function",
+                                role="assistant",
+                                function_call=None,
+                                tool_calls=None,
+                            ),
+                        )
+                    ],
+                    created=1700000000,
+                    model="gpt-4-1106-preview",
+                    object="chat.completion",
+                    system_fingerprint=None,
+                    usage=CompletionUsage(
+                        completion_tokens=9, prompt_tokens=8, total_tokens=17
+                    ),
+                )
+
+        return ChatCompletion(
+            id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        role="assistant",
+                        function_call=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="call_AbCdEfGhIjKlMnOpQrStUvWx",
+                                function=Function(
+                                    arguments='{"param1":"test_value"}',
+                                    name="test_tool",
+                                ),
+                                type="function",
+                            )
+                        ],
+                    ),
+                )
+            ],
+            created=1700000000,
+            model="gpt-4-1106-preview",
+            object="chat.completion",
+            system_fingerprint=None,
+            usage=CompletionUsage(
+                completion_tokens=9, prompt_tokens=8, total_tokens=17
+            ),
+        )
+
+    with patch(
+        "openai.resources.chat.completions.AsyncCompletions.create",
+        new_callable=AsyncMock,
+        side_effect=completion_result,
+    ) as mock_create:
+        result = await conversation.async_converse(
+            hass,
+            "Please call the test function",
+            None,
+            context,
+            agent_id=agent_id,
+        )
+
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
+    assert mock_create.mock_calls[1][2]["messages"][3] == {
+        "role": "tool",
+        "tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
+        "name": "test_tool",
+        "content": '"Test response"',
+    }
+    mock_tool.async_call.assert_awaited_once_with(
+        hass,
+        llm.ToolInput(
+            tool_name="test_tool",
+            tool_args={"param1": "test_value"},
+            platform="openai_conversation",
+            context=context,
+            user_prompt="Please call the test function",
+            language="en",
+            assistant="conversation",
+            device_id=None,
+        ),
+    )
+
+
+@patch(
+    "homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"
+)
+async def test_function_exception(
+    mock_get_tools,
+    hass: HomeAssistant,
+    mock_config_entry_with_assist: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test function call with exception."""
+    agent_id = mock_config_entry_with_assist.entry_id
+    context = Context()
+
+    mock_tool = AsyncMock()
+    mock_tool.name = "test_tool"
+    mock_tool.description = "Test function"
+    mock_tool.parameters = vol.Schema(
+        {vol.Optional("param1", description="Test parameters"): str}
+    )
+    mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception")
+
+    mock_get_tools.return_value = [mock_tool]
+
+    def completion_result(*args, messages, **kwargs):
+        for message in messages:
+            role = message["role"] if isinstance(message, dict) else message.role
+            if role == "tool":
+                return ChatCompletion(
+                    id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH",
+                    choices=[
+                        Choice(
+                            finish_reason="stop",
+                            index=0,
+                            message=ChatCompletionMessage(
+                                content="There was an error calling the function",
+                                role="assistant",
+                                function_call=None,
+                                tool_calls=None,
+                            ),
+                        )
+                    ],
+                    created=1700000000,
+                    model="gpt-4-1106-preview",
+                    object="chat.completion",
+                    system_fingerprint=None,
+                    usage=CompletionUsage(
+                        completion_tokens=9, prompt_tokens=8, total_tokens=17
+                    ),
+                )
+
+        return ChatCompletion(
+            id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        role="assistant",
+                        function_call=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="call_AbCdEfGhIjKlMnOpQrStUvWx",
+                                function=Function(
+                                    arguments='{"param1":"test_value"}',
+                                    name="test_tool",
+                                ),
+                                type="function",
+                            )
+                        ],
+                    ),
+                )
+            ],
+            created=1700000000,
+            model="gpt-4-1106-preview",
+            object="chat.completion",
+            system_fingerprint=None,
+            usage=CompletionUsage(
+                completion_tokens=9, prompt_tokens=8, total_tokens=17
+            ),
+        )
+
+    with patch(
+        "openai.resources.chat.completions.AsyncCompletions.create",
+        new_callable=AsyncMock,
+        side_effect=completion_result,
+    ) as mock_create:
+        result = await conversation.async_converse(
+            hass,
+            "Please call the test function",
+            None,
+            context,
+            agent_id=agent_id,
+        )
+
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
+    assert mock_create.mock_calls[1][2]["messages"][3] == {
+        "role": "tool",
+        "tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx",
+        "name": "test_tool",
+        "content": '{"error": "HomeAssistantError", "error_text": "Test tool exception"}',
+    }
+    mock_tool.async_call.assert_awaited_once_with(
+        hass,
+        llm.ToolInput(
+            tool_name="test_tool",
+            tool_args={"param1": "test_value"},
+            platform="openai_conversation",
+            context=context,
+            user_prompt="Please call the test function",
+            language="en",
+            assistant="conversation",
+            device_id=None,
+        ),
+    )
+
+
+async def test_assist_api_tools_conversion(
+    hass: HomeAssistant,
+    mock_config_entry_with_assist: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test that we are able to convert actual tools from Assist API."""
+    for component in [
+        "intent",
+        "todo",
+        "light",
+        "shopping_list",
+        "humidifier",
+        "climate",
+        "media_player",
+        "vacuum",
+        "cover",
+        "weather",
+    ]:
+        assert await async_setup_component(hass, component, {})
+
+    agent_id = mock_config_entry_with_assist.entry_id
+    with patch(
+        "openai.resources.chat.completions.AsyncCompletions.create",
+        new_callable=AsyncMock,
+        return_value=ChatCompletion(
+            id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
+            choices=[
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content="Hello, how can I help you?",
+                        role="assistant",
+                        function_call=None,
+                        tool_calls=None,
+                    ),
+                )
+            ],
+            created=1700000000,
+            model="gpt-3.5-turbo-0613",
+            object="chat.completion",
+            system_fingerprint=None,
+            usage=CompletionUsage(
+                completion_tokens=9, prompt_tokens=8, total_tokens=17
+            ),
+        ),
+    ) as mock_create:
+        await conversation.async_converse(hass, "hello", None, None, agent_id=agent_id)
+
+    tools = mock_create.mock_calls[0][2]["tools"]
+    assert tools
+
+
+async def test_unknown_hass_api(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    snapshot: SnapshotAssertion,
+    mock_init_component,
+) -> None:
+    """Test when we reference an API that no longer exists."""
+    hass.config_entries.async_update_entry(
+        mock_config_entry,
+        options={
+            **mock_config_entry.options,
+            CONF_LLM_HASS_API: "non-existing",
+        },
+    )
+
+    result = await conversation.async_converse(
+        hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
+    )
+
+    assert result == snapshot