From 885be98f8fcc279840cbdb6d5850231d2076e328 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 28 Mar 2023 23:37:43 -0400 Subject: [PATCH] OpenAI to use GPT3.5 (#90423) * OpenAI to use GPT3.5 * Add snapshot --- .../openai_conversation/__init__.py | 42 ++++++---------- .../openai_conversation/config_flow.py | 41 +++++++++++----- .../components/openai_conversation/const.py | 8 +--- .../openai_conversation/manifest.json | 2 +- requirements_all.txt | 2 +- requirements_test_all.txt | 2 +- .../snapshots/test_init.ambr | 34 +++++++++++++ .../openai_conversation/test_config_flow.py | 6 +-- .../openai_conversation/test_init.py | 48 ++++++++----------- 9 files changed, 107 insertions(+), 78 deletions(-) create mode 100644 tests/components/openai_conversation/snapshots/test_init.ambr diff --git a/homeassistant/components/openai_conversation/__init__.py b/homeassistant/components/openai_conversation/__init__.py index 355b7764b08..3e67d4e27da 100644 --- a/homeassistant/components/openai_conversation/__init__.py +++ b/homeassistant/components/openai_conversation/__init__.py @@ -16,13 +16,13 @@ from homeassistant.helpers import area_registry as ar, intent, template from homeassistant.util import ulid from .const import ( + CONF_CHAT_MODEL, CONF_MAX_TOKENS, - CONF_MODEL, CONF_PROMPT, CONF_TEMPERATURE, CONF_TOP_P, + DEFAULT_CHAT_MODEL, DEFAULT_MAX_TOKENS, - DEFAULT_MODEL, DEFAULT_PROMPT, DEFAULT_TEMPERATURE, DEFAULT_TOP_P, @@ -63,7 +63,7 @@ class OpenAIAgent(conversation.AbstractConversationAgent): """Initialize the agent.""" self.hass = hass self.entry = entry - self.history: dict[str, str] = {} + self.history: dict[str, list[dict]] = {} @property def attribution(self): @@ -75,14 +75,14 @@ class OpenAIAgent(conversation.AbstractConversationAgent): ) -> conversation.ConversationResult: """Process a sentence.""" raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) - model = self.entry.options.get(CONF_MODEL, DEFAULT_MODEL) + model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) if user_input.conversation_id in self.history: conversation_id = user_input.conversation_id - prompt = self.history[conversation_id] + messages = self.history[conversation_id] else: conversation_id = ulid.ulid() try: @@ -97,25 +97,16 @@ class OpenAIAgent(conversation.AbstractConversationAgent): return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) + messages = [{"role": "system", "content": prompt}] - user_name = "User" - if ( - user_input.context.user_id - and ( - user := await self.hass.auth.async_get_user(user_input.context.user_id) - ) - and user.name - ): - user_name = user.name + messages.append({"role": "user", "content": user_input.text}) - prompt += f"\n{user_name}: {user_input.text}\nSmart home: " - - _LOGGER.debug("Prompt for %s: %s", model, prompt) + _LOGGER.debug("Prompt for %s: %s", model, messages) try: - result = await openai.Completion.acreate( - engine=model, - prompt=prompt, + result = await openai.ChatCompletion.acreate( + model=model, + messages=messages, max_tokens=max_tokens, top_p=top_p, temperature=temperature, @@ -132,15 +123,12 @@ class OpenAIAgent(conversation.AbstractConversationAgent): ) _LOGGER.debug("Response %s", result) - response = result["choices"][0]["text"].strip() - self.history[conversation_id] = prompt + response - - stripped_response = response - if response.startswith("Smart home:"): - stripped_response = response[11:].strip() + response = result["choices"][0]["message"] + messages.append(response) + self.history[conversation_id] = messages intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_speech(stripped_response) + intent_response.async_set_speech(response["content"]) return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) diff --git a/homeassistant/components/openai_conversation/config_flow.py b/homeassistant/components/openai_conversation/config_flow.py index 2db5e98a1f4..892d794bcaf 100644 --- a/homeassistant/components/openai_conversation/config_flow.py +++ b/homeassistant/components/openai_conversation/config_flow.py @@ -22,13 +22,13 @@ from homeassistant.helpers.selector import ( ) from .const import ( + CONF_CHAT_MODEL, CONF_MAX_TOKENS, - CONF_MODEL, CONF_PROMPT, CONF_TEMPERATURE, CONF_TOP_P, + DEFAULT_CHAT_MODEL, DEFAULT_MAX_TOKENS, - DEFAULT_MODEL, DEFAULT_PROMPT, DEFAULT_TEMPERATURE, DEFAULT_TOP_P, @@ -46,7 +46,7 @@ STEP_USER_DATA_SCHEMA = vol.Schema( DEFAULT_OPTIONS = types.MappingProxyType( { CONF_PROMPT: DEFAULT_PROMPT, - CONF_MODEL: DEFAULT_MODEL, + CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL, CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS, CONF_TOP_P: DEFAULT_TOP_P, CONF_TEMPERATURE: DEFAULT_TEMPERATURE, @@ -131,13 +131,32 @@ def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict: if not options: options = DEFAULT_OPTIONS return { - vol.Required(CONF_PROMPT, default=options.get(CONF_PROMPT)): TemplateSelector(), - vol.Required(CONF_MODEL, default=options.get(CONF_MODEL)): str, - vol.Required(CONF_MAX_TOKENS, default=options.get(CONF_MAX_TOKENS)): int, - vol.Required(CONF_TOP_P, default=options.get(CONF_TOP_P)): NumberSelector( - NumberSelectorConfig(min=0, max=1, step=0.05) - ), - vol.Required( - CONF_TEMPERATURE, default=options.get(CONF_TEMPERATURE) + vol.Optional( + CONF_PROMPT, + description={"suggested_value": options[CONF_PROMPT]}, + default=DEFAULT_PROMPT, + ): TemplateSelector(), + vol.Optional( + CONF_CHAT_MODEL, + description={ + # New key in HA 2023.4 + "suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) + }, + default=DEFAULT_CHAT_MODEL, + ): str, + vol.Optional( + CONF_MAX_TOKENS, + description={"suggested_value": options[CONF_MAX_TOKENS]}, + default=DEFAULT_MAX_TOKENS, + ): int, + vol.Optional( + CONF_TOP_P, + description={"suggested_value": options[CONF_TOP_P]}, + default=DEFAULT_TOP_P, + ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), + vol.Optional( + CONF_TEMPERATURE, + description={"suggested_value": options[CONF_TEMPERATURE]}, + default=DEFAULT_TEMPERATURE, ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), } diff --git a/homeassistant/components/openai_conversation/const.py b/homeassistant/components/openai_conversation/const.py index ed914efeb6e..88289eb90b0 100644 --- a/homeassistant/components/openai_conversation/const.py +++ b/homeassistant/components/openai_conversation/const.py @@ -22,13 +22,9 @@ An overview of the areas and the devices in this smart home: Answer the user's questions about the world truthfully. If the user wants to control a device, reject the request and suggest using the Home Assistant app. - -Now finish this conversation: - -Smart home: How can I assist? """ -CONF_MODEL = "model" -DEFAULT_MODEL = "text-davinci-003" +CONF_CHAT_MODEL = "chat_model" +DEFAULT_CHAT_MODEL = "gpt-3.5-turbo" CONF_MAX_TOKENS = "max_tokens" DEFAULT_MAX_TOKENS = 150 CONF_TOP_P = "top_p" diff --git a/homeassistant/components/openai_conversation/manifest.json b/homeassistant/components/openai_conversation/manifest.json index 0e245eb78b5..88d347355e9 100644 --- a/homeassistant/components/openai_conversation/manifest.json +++ b/homeassistant/components/openai_conversation/manifest.json @@ -7,5 +7,5 @@ "documentation": "https://www.home-assistant.io/integrations/openai_conversation", "integration_type": "service", "iot_class": "cloud_polling", - "requirements": ["openai==0.26.2"] + "requirements": ["openai==0.27.2"] } diff --git a/requirements_all.txt b/requirements_all.txt index 49ac9171eca..aeb7033d0a6 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -1269,7 +1269,7 @@ open-garage==0.2.0 open-meteo==0.2.1 # homeassistant.components.openai_conversation -openai==0.26.2 +openai==0.27.2 # homeassistant.components.opencv # opencv-python-headless==4.6.0.66 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 4213338da8a..a21ef8b32b8 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -947,7 +947,7 @@ open-garage==0.2.0 open-meteo==0.2.1 # homeassistant.components.openai_conversation -openai==0.26.2 +openai==0.27.2 # homeassistant.components.openerz openerz-api==0.2.0 diff --git a/tests/components/openai_conversation/snapshots/test_init.ambr b/tests/components/openai_conversation/snapshots/test_init.ambr new file mode 100644 index 00000000000..bc06f51f416 --- /dev/null +++ b/tests/components/openai_conversation/snapshots/test_init.ambr @@ -0,0 +1,34 @@ +# serializer version: 1 +# name: test_default_prompt + list([ + dict({ + 'content': ''' + This smart home is controlled by Home Assistant. + + An overview of the areas and the devices in this smart home: + + Test Area: + - Test Device (Test Model) + + Test Area 2: + - Test Device 2 + - Test Device 3 (Test Model 3A) + - Test Device 4 + - 1 (3) + + Answer the user's questions about the world truthfully. + + If the user wants to control a device, reject the request and suggest using the Home Assistant app. + ''', + 'role': 'system', + }), + dict({ + 'content': 'hello', + 'role': 'user', + }), + dict({ + 'content': 'Hello, how can I help you?', + 'role': 'assistant', + }), + ]) +# --- diff --git a/tests/components/openai_conversation/test_config_flow.py b/tests/components/openai_conversation/test_config_flow.py index 25849882e82..4ce677d8cca 100644 --- a/tests/components/openai_conversation/test_config_flow.py +++ b/tests/components/openai_conversation/test_config_flow.py @@ -6,8 +6,8 @@ import pytest from homeassistant import config_entries from homeassistant.components.openai_conversation.const import ( - CONF_MODEL, - DEFAULT_MODEL, + CONF_CHAT_MODEL, + DEFAULT_CHAT_MODEL, DOMAIN, ) from homeassistant.core import HomeAssistant @@ -72,7 +72,7 @@ async def test_options( assert options["type"] == FlowResultType.CREATE_ENTRY assert options["data"]["prompt"] == "Speak like a pirate" assert options["data"]["max_tokens"] == 200 - assert options["data"][CONF_MODEL] == DEFAULT_MODEL + assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL @pytest.mark.parametrize( diff --git a/tests/components/openai_conversation/test_init.py b/tests/components/openai_conversation/test_init.py index 3b78a90f40e..144d77beab5 100644 --- a/tests/components/openai_conversation/test_init.py +++ b/tests/components/openai_conversation/test_init.py @@ -2,6 +2,7 @@ from unittest.mock import patch from openai import error +from syrupy.assertion import SnapshotAssertion from homeassistant.components import conversation from homeassistant.core import Context, HomeAssistant @@ -15,6 +16,7 @@ async def test_default_prompt( mock_init_component, area_registry: ar.AreaRegistry, device_registry: dr.DeviceRegistry, + snapshot: SnapshotAssertion, ) -> None: """Test that the default prompt works.""" for i in range(3): @@ -86,40 +88,30 @@ async def test_default_prompt( model=3, suggested_area="Test Area 2", ) - with patch("openai.Completion.acreate") as mock_create: + with patch( + "openai.ChatCompletion.acreate", + return_value={ + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello, how can I help you?", + } + } + ] + }, + ) as mock_create: result = await conversation.async_converse(hass, "hello", None, Context()) assert result.response.response_type == intent.IntentResponseType.ACTION_DONE - assert ( - mock_create.mock_calls[0][2]["prompt"] - == """This smart home is controlled by Home Assistant. - -An overview of the areas and the devices in this smart home: - -Test Area: -- Test Device (Test Model) - -Test Area 2: -- Test Device 2 -- Test Device 3 (Test Model 3A) -- Test Device 4 -- 1 (3) - -Answer the user's questions about the world truthfully. - -If the user wants to control a device, reject the request and suggest using the Home Assistant app. - -Now finish this conversation: - -Smart home: How can I assist? -User: hello -Smart home: """ - ) + assert mock_create.mock_calls[0][2]["messages"] == snapshot async def test_error_handling(hass: HomeAssistant, mock_init_component) -> None: """Test that the default prompt works.""" - with patch("openai.Completion.acreate", side_effect=error.ServiceUnavailableError): + with patch( + "openai.ChatCompletion.acreate", side_effect=error.ServiceUnavailableError + ): result = await conversation.async_converse(hass, "hello", None, Context()) assert result.response.response_type == intent.IntentResponseType.ERROR, result @@ -138,7 +130,7 @@ async def test_template_error( ) with patch( "openai.Engine.list", - ), patch("openai.Completion.acreate"): + ), patch("openai.ChatCompletion.acreate"): await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.async_block_till_done() result = await conversation.async_converse(hass, "hello", None, Context())