Add configuration options to OpenAI integration (#86768)

* Added multiple features to OpenAI integration

* Fixed failed test

* Removed features and improved tests

* initiated component before starting options flow
This commit is contained in:
Ben Dews 2023-01-31 00:24:11 +11:00 committed by GitHub
parent 032a37b121
commit 21d1c647c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 185 additions and 19 deletions

View File

@ -15,7 +15,18 @@ from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
from homeassistant.helpers import area_registry, intent, template from homeassistant.helpers import area_registry, intent, template
from homeassistant.util import ulid from homeassistant.util import ulid
from .const import DEFAULT_MODEL, DEFAULT_PROMPT from .const import (
CONF_MAX_TOKENS,
CONF_MODEL,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_P,
DEFAULT_MAX_TOKENS,
DEFAULT_MODEL,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -63,7 +74,11 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
self, user_input: conversation.ConversationInput self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process a sentence.""" """Process a sentence."""
model = DEFAULT_MODEL raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
model = self.entry.options.get(CONF_MODEL, DEFAULT_MODEL)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
if user_input.conversation_id in self.history: if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id conversation_id = user_input.conversation_id
@ -71,7 +86,7 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
else: else:
conversation_id = ulid.ulid() conversation_id = ulid.ulid()
try: try:
prompt = self._async_generate_prompt() prompt = self._async_generate_prompt(raw_prompt)
except TemplateError as err: except TemplateError as err:
_LOGGER.error("Error rendering prompt: %s", err) _LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
@ -98,15 +113,14 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
_LOGGER.debug("Prompt for %s: %s", model, prompt) _LOGGER.debug("Prompt for %s: %s", model, prompt)
try: try:
result = await self.hass.async_add_executor_job( result = await openai.Completion.acreate(
partial(
openai.Completion.create,
engine=model, engine=model,
prompt=prompt, prompt=prompt,
max_tokens=150, max_tokens=max_tokens,
top_p=top_p,
temperature=temperature,
user=conversation_id, user=conversation_id,
) )
)
except error.OpenAIError as err: except error.OpenAIError as err:
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error( intent_response.async_set_error(
@ -131,9 +145,9 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id
) )
def _async_generate_prompt(self) -> str: def _async_generate_prompt(self, raw_prompt: str) -> str:
"""Generate a prompt for the user.""" """Generate a prompt for the user."""
return template.Template(DEFAULT_PROMPT, self.hass).async_render( return template.Template(raw_prompt, self.hass).async_render(
{ {
"ha_name": self.hass.config.location_name, "ha_name": self.hass.config.location_name,
"areas": list(area_registry.async_get(self.hass).areas.values()), "areas": list(area_registry.async_get(self.hass).areas.values()),

View File

@ -3,6 +3,8 @@ from __future__ import annotations
from functools import partial from functools import partial
import logging import logging
import types
from types import MappingProxyType
from typing import Any from typing import Any
import openai import openai
@ -13,8 +15,26 @@ from homeassistant import config_entries
from homeassistant.const import CONF_API_KEY from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.selector import (
NumberSelector,
NumberSelectorConfig,
TextSelector,
TextSelectorConfig,
)
from .const import DOMAIN from .const import (
CONF_MAX_TOKENS,
CONF_MODEL,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_P,
DEFAULT_MAX_TOKENS,
DEFAULT_MODEL,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DOMAIN,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -24,6 +44,16 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
} }
) )
DEFAULT_OPTIONS = types.MappingProxyType(
{
CONF_PROMPT: DEFAULT_PROMPT,
CONF_MODEL: DEFAULT_MODEL,
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
CONF_TOP_P: DEFAULT_TOP_P,
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
}
)
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect. """Validate the user input allows us to connect.
@ -68,3 +98,49 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return self.async_show_form( return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
) )
@staticmethod
def async_get_options_flow(
config_entry: config_entries.ConfigEntry,
) -> config_entries.OptionsFlow:
"""Create the options flow."""
return OptionsFlow(config_entry)
class OptionsFlow(config_entries.OptionsFlow):
"""OpenAI config flow options handler."""
def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
"""Initialize options flow."""
self.config_entry = config_entry
async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Manage the options."""
if user_input is not None:
return self.async_create_entry(title="OpenAI Conversation", data=user_input)
schema = openai_config_option_schema(self.config_entry.options)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema),
)
def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
"""Return a schema for OpenAI completion options."""
if not options:
options = DEFAULT_OPTIONS
return {
vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TextSelector(
TextSelectorConfig(multiline=True)
),
vol.Required(CONF_MODEL, default=options.get(CONF_MODEL)): str,
vol.Required(CONF_MAX_TOKENS, default=options.get(CONF_MAX_TOKENS)): int,
vol.Required(CONF_TOP_P, default=options.get(CONF_TOP_P)): NumberSelector(
NumberSelectorConfig(min=0, max=1, step=0.05)
),
vol.Required(
CONF_TEMPERATURE, default=options.get(CONF_TEMPERATURE)
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
}

View File

@ -2,7 +2,6 @@
DOMAIN = "openai_conversation" DOMAIN = "openai_conversation"
CONF_PROMPT = "prompt" CONF_PROMPT = "prompt"
DEFAULT_MODEL = "text-davinci-003"
DEFAULT_PROMPT = """This smart home is controlled by Home Assistant. DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home: An overview of the areas and the devices in this smart home:
@ -28,3 +27,11 @@ Now finish this conversation:
Smart home: How can I assist? Smart home: How can I assist?
""" """
CONF_MODEL = "model"
DEFAULT_MODEL = "text-davinci-003"
CONF_MAX_TOKENS = "max_tokens"
DEFAULT_MAX_TOKENS = 150
CONF_TOP_P = "top_p"
DEFAULT_TOP_P = 1
CONF_TEMPERATURE = "temperature"
DEFAULT_TEMPERATURE = 0.5

View File

@ -15,5 +15,18 @@
"abort": { "abort": {
"single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]" "single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]"
} }
},
"options": {
"step": {
"init": {
"data": {
"prompt": "Prompt Template",
"model": "Completion Model",
"max_tokens": "Maximum tokens to return in response",
"temperature": "Temperature",
"top_p": "Top P"
}
}
}
} }
} }

View File

@ -15,5 +15,18 @@
} }
} }
} }
},
"options": {
"step": {
"init": {
"data": {
"prompt": "Prompt Template",
"model": "Completion Model",
"max_tokens": "Maximum tokens to return in response",
"temperature": "Temperature",
"top_p": "Top P"
}
}
}
} }
} }

View File

@ -5,7 +5,11 @@ from openai.error import APIConnectionError, AuthenticationError, InvalidRequest
import pytest import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.openai_conversation.const import DOMAIN from homeassistant.components.openai_conversation.const import (
CONF_MODEL,
DEFAULT_MODEL,
DOMAIN,
)
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
@ -50,6 +54,27 @@ async def test_form(hass: HomeAssistant) -> None:
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
async def test_options(
hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
"""Test the options form."""
options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id
)
options = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
{
"prompt": "Speak like a pirate",
"max_tokens": 200,
},
)
await hass.async_block_till_done()
assert options["type"] == FlowResultType.CREATE_ENTRY
assert options["data"]["prompt"] == "Speak like a pirate"
assert options["data"]["max_tokens"] == 200
assert options["data"][CONF_MODEL] == DEFAULT_MODEL
@pytest.mark.parametrize( @pytest.mark.parametrize(
"side_effect, error", "side_effect, error",
[ [

View File

@ -68,11 +68,10 @@ async def test_default_prompt(hass, mock_init_component):
device.id, disabled_by=device_registry.DeviceEntryDisabler.USER device.id, disabled_by=device_registry.DeviceEntryDisabler.USER
) )
with patch("openai.Completion.create") as mock_create: with patch("openai.Completion.acreate") as mock_create:
result = await conversation.async_converse(hass, "hello", None, Context()) result = await conversation.async_converse(hass, "hello", None, Context())
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert ( assert (
mock_create.mock_calls[0][2]["prompt"] mock_create.mock_calls[0][2]["prompt"]
== """This smart home is controlled by Home Assistant. == """This smart home is controlled by Home Assistant.
@ -101,7 +100,26 @@ Smart home: """
async def test_error_handling(hass, mock_init_component): async def test_error_handling(hass, mock_init_component):
"""Test that the default prompt works.""" """Test that the default prompt works."""
with patch("openai.Completion.create", side_effect=error.ServiceUnavailableError): with patch("openai.Completion.acreate", side_effect=error.ServiceUnavailableError):
result = await conversation.async_converse(hass, "hello", None, Context())
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
async def test_template_error(hass, mock_config_entry, mock_init_component):
"""Test that template error handling works."""
options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id
)
await hass.config_entries.options.async_configure(
options_flow["flow_id"],
{
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
},
)
await hass.async_block_till_done()
with patch("openai.Completion.acreate"):
result = await conversation.async_converse(hass, "hello", None, Context()) result = await conversation.async_converse(hass, "hello", None, Context())
assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.response_type == intent.IntentResponseType.ERROR, result