Bump openai end switch from dall-e-2 to dall-e-3 (#104998)

* Bump openai

* Fix tests

* Apply suggestions from code review

* Undo conftest changes

* Raise repasir issue

* Explicitly use async mock for chat.completions.create

It is not always detected correctly as async because it uses a decorator

* removed duplicated message

* ruff

* Compatibility with old pydantic versions

* Compatibility with old pydantic versions

* More tests

* Apply suggestions from code review

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* Apply suggestions from code review

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
Denis Shulyaka 2023-12-11 17:47:26 +03:00 committed by GitHub
parent c0314cd05c
commit 1242456ff1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 269 additions and 71 deletions

View File

@ -1,12 +1,10 @@
"""The OpenAI Conversation integration.""" """The OpenAI Conversation integration."""
from __future__ import annotations from __future__ import annotations
from functools import partial
import logging import logging
from typing import Literal from typing import Literal
import openai import openai
from openai import error
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
@ -23,7 +21,13 @@ from homeassistant.exceptions import (
HomeAssistantError, HomeAssistantError,
TemplateError, TemplateError,
) )
from homeassistant.helpers import config_validation as cv, intent, selector, template from homeassistant.helpers import (
config_validation as cv,
intent,
issue_registry as ir,
selector,
template,
)
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.util import ulid from homeassistant.util import ulid
@ -52,17 +56,38 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def render_image(call: ServiceCall) -> ServiceResponse: async def render_image(call: ServiceCall) -> ServiceResponse:
"""Render an image with dall-e.""" """Render an image with dall-e."""
try: client = hass.data[DOMAIN][call.data["config_entry"]]
response = await openai.Image.acreate(
api_key=hass.data[DOMAIN][call.data["config_entry"]], if call.data["size"] in ("256", "512", "1024"):
prompt=call.data["prompt"], ir.async_create_issue(
n=1, hass,
size=f'{call.data["size"]}x{call.data["size"]}', DOMAIN,
"image_size_deprecated_format",
breaks_in_ha_version="2024.7.0",
is_fixable=False,
is_persistent=True,
learn_more_url="https://www.home-assistant.io/integrations/openai_conversation/",
severity=ir.IssueSeverity.WARNING,
translation_key="image_size_deprecated_format",
) )
except error.OpenAIError as err: size = "1024x1024"
else:
size = call.data["size"]
try:
response = await client.images.generate(
model="dall-e-3",
prompt=call.data["prompt"],
size=size,
quality=call.data["quality"],
style=call.data["style"],
response_format="url",
n=1,
)
except openai.OpenAIError as err:
raise HomeAssistantError(f"Error generating image: {err}") from err raise HomeAssistantError(f"Error generating image: {err}") from err
return response["data"][0] return response.data[0].model_dump(exclude={"b64_json"})
hass.services.async_register( hass.services.async_register(
DOMAIN, DOMAIN,
@ -76,7 +101,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
} }
), ),
vol.Required("prompt"): cv.string, vol.Required("prompt"): cv.string,
vol.Optional("size", default="512"): vol.In(("256", "512", "1024")), vol.Optional("size", default="1024x1024"): vol.In(
("1024x1024", "1024x1792", "1792x1024", "256", "512", "1024")
),
vol.Optional("quality", default="standard"): vol.In(("standard", "hd")),
vol.Optional("style", default="vivid"): vol.In(("vivid", "natural")),
} }
), ),
supports_response=SupportsResponse.ONLY, supports_response=SupportsResponse.ONLY,
@ -86,21 +115,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up OpenAI Conversation from a config entry.""" """Set up OpenAI Conversation from a config entry."""
client = openai.AsyncOpenAI(api_key=entry.data[CONF_API_KEY])
try: try:
await hass.async_add_executor_job( await hass.async_add_executor_job(client.with_options(timeout=10.0).models.list)
partial( except openai.AuthenticationError as err:
openai.Model.list,
api_key=entry.data[CONF_API_KEY],
request_timeout=10,
)
)
except error.AuthenticationError as err:
_LOGGER.error("Invalid API key: %s", err) _LOGGER.error("Invalid API key: %s", err)
return False return False
except error.OpenAIError as err: except openai.OpenAIError as err:
raise ConfigEntryNotReady(err) from err raise ConfigEntryNotReady(err) from err
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry.data[CONF_API_KEY] hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
conversation.async_set_agent(hass, entry, OpenAIAgent(hass, entry)) conversation.async_set_agent(hass, entry, OpenAIAgent(hass, entry))
return True return True
@ -160,9 +184,10 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
_LOGGER.debug("Prompt for %s: %s", model, messages) _LOGGER.debug("Prompt for %s: %s", model, messages)
client = self.hass.data[DOMAIN][self.entry.entry_id]
try: try:
result = await openai.ChatCompletion.acreate( result = await client.chat.completions.create(
api_key=self.entry.data[CONF_API_KEY],
model=model, model=model,
messages=messages, messages=messages,
max_tokens=max_tokens, max_tokens=max_tokens,
@ -170,7 +195,7 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
temperature=temperature, temperature=temperature,
user=conversation_id, user=conversation_id,
) )
except error.OpenAIError as err: except openai.OpenAIError as err:
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error( intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN, intent.IntentResponseErrorCode.UNKNOWN,
@ -181,7 +206,7 @@ class OpenAIAgent(conversation.AbstractConversationAgent):
) )
_LOGGER.debug("Response %s", result) _LOGGER.debug("Response %s", result)
response = result["choices"][0]["message"] response = result.choices[0].message.model_dump(include={"role", "content"})
messages.append(response) messages.append(response)
self.history[conversation_id] = messages self.history[conversation_id] = messages

View File

@ -1,14 +1,12 @@
"""Config flow for OpenAI Conversation integration.""" """Config flow for OpenAI Conversation integration."""
from __future__ import annotations from __future__ import annotations
from functools import partial
import logging import logging
import types import types
from types import MappingProxyType from types import MappingProxyType
from typing import Any from typing import Any
import openai import openai
from openai import error
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
@ -59,8 +57,8 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user. Data has the keys from STEP_USER_DATA_SCHEMA with values provided by the user.
""" """
openai.api_key = data[CONF_API_KEY] client = openai.AsyncOpenAI(api_key=data[CONF_API_KEY])
await hass.async_add_executor_job(partial(openai.Model.list, request_timeout=10)) await hass.async_add_executor_job(client.with_options(timeout=10.0).models.list)
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
@ -81,9 +79,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
try: try:
await validate_input(self.hass, user_input) await validate_input(self.hass, user_input)
except error.APIConnectionError: except openai.APIConnectionError:
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
except error.AuthenticationError: except openai.AuthenticationError:
errors["base"] = "invalid_auth" errors["base"] = "invalid_auth"
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception") _LOGGER.exception("Unexpected exception")

View File

@ -7,5 +7,5 @@
"documentation": "https://www.home-assistant.io/integrations/openai_conversation", "documentation": "https://www.home-assistant.io/integrations/openai_conversation",
"integration_type": "service", "integration_type": "service",
"iot_class": "cloud_polling", "iot_class": "cloud_polling",
"requirements": ["openai==0.27.2"] "requirements": ["openai==1.3.8"]
} }

View File

@ -11,12 +11,30 @@ generate_image:
text: text:
multiline: true multiline: true
size: size:
required: true required: false
example: "512" example: "1024x1024"
default: "512" default: "1024x1024"
selector: selector:
select: select:
options: options:
- "256" - "1024x1024"
- "512" - "1024x1792"
- "1024" - "1792x1024"
quality:
required: false
example: "standard"
default: "standard"
selector:
select:
options:
- "standard"
- "hd"
style:
required: false
example: "vivid"
default: "vivid"
selector:
select:
options:
- "vivid"
- "natural"

View File

@ -43,8 +43,22 @@
"size": { "size": {
"name": "Size", "name": "Size",
"description": "The size of the image to generate" "description": "The size of the image to generate"
},
"quality": {
"name": "Quality",
"description": "The quality of the image that will be generated"
},
"style": {
"name": "Style",
"description": "The style of the generated image"
} }
} }
} }
},
"issues": {
"image_size_deprecated_format": {
"title": "Deprecated size format for image generation service",
"description": "OpenAI is now using Dall-E 3 to generate images when calling `openai_conversation.generate_image`, which supports different sizes. Valid values are now \"1024x1024\", \"1024x1792\", \"1792x1024\". The old values of \"256\", \"512\", \"1024\" are currently interpreted as \"1024x1024\".\nPlease update your scripts or automations with the new parameters."
}
} }
} }

View File

@ -1393,7 +1393,7 @@ open-garage==0.2.0
open-meteo==0.3.1 open-meteo==0.3.1
# homeassistant.components.openai_conversation # homeassistant.components.openai_conversation
openai==0.27.2 openai==1.3.8
# homeassistant.components.opencv # homeassistant.components.opencv
# opencv-python-headless==4.6.0.66 # opencv-python-headless==4.6.0.66

View File

@ -1087,7 +1087,7 @@ open-garage==0.2.0
open-meteo==0.3.1 open-meteo==0.3.1
# homeassistant.components.openai_conversation # homeassistant.components.openai_conversation
openai==0.27.2 openai==1.3.8
# homeassistant.components.openerz # homeassistant.components.openerz
openerz-api==0.2.0 openerz-api==0.2.0

View File

@ -25,7 +25,7 @@ def mock_config_entry(hass):
async def mock_init_component(hass, mock_config_entry): async def mock_init_component(hass, mock_config_entry):
"""Initialize integration.""" """Initialize integration."""
with patch( with patch(
"openai.Model.list", "openai.resources.models.AsyncModels.list",
): ):
assert await async_setup_component(hass, "openai_conversation", {}) assert await async_setup_component(hass, "openai_conversation", {})
await hass.async_block_till_done() await hass.async_block_till_done()

View File

@ -1,7 +1,8 @@
"""Test the OpenAI Conversation config flow.""" """Test the OpenAI Conversation config flow."""
from unittest.mock import patch from unittest.mock import patch
from openai.error import APIConnectionError, AuthenticationError, InvalidRequestError from httpx import Response
from openai import APIConnectionError, AuthenticationError, BadRequestError
import pytest import pytest
from homeassistant import config_entries from homeassistant import config_entries
@ -32,7 +33,7 @@ async def test_form(hass: HomeAssistant) -> None:
assert result["errors"] is None assert result["errors"] is None
with patch( with patch(
"homeassistant.components.openai_conversation.config_flow.openai.Model.list", "homeassistant.components.openai_conversation.config_flow.openai.resources.models.AsyncModels.list",
), patch( ), patch(
"homeassistant.components.openai_conversation.async_setup_entry", "homeassistant.components.openai_conversation.async_setup_entry",
return_value=True, return_value=True,
@ -76,9 +77,19 @@ async def test_options(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "error"), ("side_effect", "error"),
[ [
(APIConnectionError(""), "cannot_connect"), (APIConnectionError(request=None), "cannot_connect"),
(AuthenticationError, "invalid_auth"), (
(InvalidRequestError, "unknown"), AuthenticationError(
response=Response(status_code=None, request=""), body=None, message=None
),
"invalid_auth",
),
(
BadRequestError(
response=Response(status_code=None, request=""), body=None, message=None
),
"unknown",
),
], ],
) )
async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None: async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> None:
@ -88,7 +99,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
) )
with patch( with patch(
"homeassistant.components.openai_conversation.config_flow.openai.Model.list", "homeassistant.components.openai_conversation.config_flow.openai.resources.models.AsyncModels.list",
side_effect=side_effect, side_effect=side_effect,
): ):
result2 = await hass.config_entries.flow.async_configure( result2 = await hass.config_entries.flow.async_configure(

View File

@ -1,7 +1,18 @@
"""Tests for the OpenAI integration.""" """Tests for the OpenAI integration."""
from unittest.mock import patch from unittest.mock import AsyncMock, patch
from openai import error from httpx import Response
from openai import (
APIConnectionError,
AuthenticationError,
BadRequestError,
RateLimitError,
)
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.completion_usage import CompletionUsage
from openai.types.image import Image
from openai.types.images_response import ImagesResponse
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
@ -9,6 +20,7 @@ from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import area_registry as ar, device_registry as dr, intent from homeassistant.helpers import area_registry as ar, device_registry as dr, intent
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -94,17 +106,30 @@ async def test_default_prompt(
suggested_area="Test Area 2", suggested_area="Test Area 2",
) )
with patch( with patch(
"openai.ChatCompletion.acreate", "openai.resources.chat.completions.AsyncCompletions.create",
return_value={ new_callable=AsyncMock,
"choices": [ return_value=ChatCompletion(
{ id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
"message": { choices=[
"role": "assistant", Choice(
"content": "Hello, how can I help you?", finish_reason="stop",
} index=0,
} message=ChatCompletionMessage(
] content="Hello, how can I help you?",
}, role="assistant",
function_call=None,
tool_calls=None,
),
)
],
created=1700000000,
model="gpt-3.5-turbo-0613",
object="chat.completion",
system_fingerprint=None,
usage=CompletionUsage(
completion_tokens=9, prompt_tokens=8, total_tokens=17
),
),
) as mock_create: ) as mock_create:
result = await conversation.async_converse( result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
@ -119,7 +144,11 @@ async def test_error_handling(
) -> None: ) -> None:
"""Test that the default prompt works.""" """Test that the default prompt works."""
with patch( with patch(
"openai.ChatCompletion.acreate", side_effect=error.ServiceUnavailableError "openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
side_effect=RateLimitError(
response=Response(status_code=None, request=""), body=None, message=None
),
): ):
result = await conversation.async_converse( result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
@ -140,8 +169,11 @@ async def test_template_error(
}, },
) )
with patch( with patch(
"openai.Model.list", "openai.resources.models.AsyncModels.list",
), patch("openai.ChatCompletion.acreate"): ), patch(
"openai.resources.chat.completions.AsyncCompletions.create",
new_callable=AsyncMock,
):
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
result = await conversation.async_converse( result = await conversation.async_converse(
@ -169,15 +201,67 @@ async def test_conversation_agent(
[ [
( (
{"prompt": "Picture of a dog"}, {"prompt": "Picture of a dog"},
{"prompt": "Picture of a dog", "size": "512x512"}, {
"prompt": "Picture of a dog",
"size": "1024x1024",
"quality": "standard",
"style": "vivid",
},
),
(
{
"prompt": "Picture of a dog",
"size": "1024x1792",
"quality": "hd",
"style": "vivid",
},
{
"prompt": "Picture of a dog",
"size": "1024x1792",
"quality": "hd",
"style": "vivid",
},
),
(
{
"prompt": "Picture of a dog",
"size": "1792x1024",
"quality": "standard",
"style": "natural",
},
{
"prompt": "Picture of a dog",
"size": "1792x1024",
"quality": "standard",
"style": "natural",
},
), ),
( (
{"prompt": "Picture of a dog", "size": "256"}, {"prompt": "Picture of a dog", "size": "256"},
{"prompt": "Picture of a dog", "size": "256x256"}, {
"prompt": "Picture of a dog",
"size": "1024x1024",
"quality": "standard",
"style": "vivid",
},
),
(
{"prompt": "Picture of a dog", "size": "512"},
{
"prompt": "Picture of a dog",
"size": "1024x1024",
"quality": "standard",
"style": "vivid",
},
), ),
( (
{"prompt": "Picture of a dog", "size": "1024"}, {"prompt": "Picture of a dog", "size": "1024"},
{"prompt": "Picture of a dog", "size": "1024x1024"}, {
"prompt": "Picture of a dog",
"size": "1024x1024",
"quality": "standard",
"style": "vivid",
},
), ),
], ],
) )
@ -190,11 +274,22 @@ async def test_generate_image_service(
) -> None: ) -> None:
"""Test generate image service.""" """Test generate image service."""
service_data["config_entry"] = mock_config_entry.entry_id service_data["config_entry"] = mock_config_entry.entry_id
expected_args["api_key"] = mock_config_entry.data["api_key"] expected_args["model"] = "dall-e-3"
expected_args["response_format"] = "url"
expected_args["n"] = 1 expected_args["n"] = 1
with patch( with patch(
"openai.Image.acreate", return_value={"data": [{"url": "A"}]} "openai.resources.images.AsyncImages.generate",
return_value=ImagesResponse(
created=1700000000,
data=[
Image(
b64_json=None,
revised_prompt="A clear and detailed picture of an ordinary canine",
url="A",
)
],
),
) as mock_create: ) as mock_create:
response = await hass.services.async_call( response = await hass.services.async_call(
"openai_conversation", "openai_conversation",
@ -204,7 +299,10 @@ async def test_generate_image_service(
return_response=True, return_response=True,
) )
assert response == {"url": "A"} assert response == {
"url": "A",
"revised_prompt": "A clear and detailed picture of an ordinary canine",
}
assert len(mock_create.mock_calls) == 1 assert len(mock_create.mock_calls) == 1
assert mock_create.mock_calls[0][2] == expected_args assert mock_create.mock_calls[0][2] == expected_args
@ -216,7 +314,10 @@ async def test_generate_image_service_error(
) -> None: ) -> None:
"""Test generate image service handles errors.""" """Test generate image service handles errors."""
with patch( with patch(
"openai.Image.acreate", side_effect=error.ServiceUnavailableError("Reason") "openai.resources.images.AsyncImages.generate",
side_effect=RateLimitError(
response=Response(status_code=None, request=""), body=None, message="Reason"
),
), pytest.raises(HomeAssistantError, match="Error generating image: Reason"): ), pytest.raises(HomeAssistantError, match="Error generating image: Reason"):
await hass.services.async_call( await hass.services.async_call(
"openai_conversation", "openai_conversation",
@ -228,3 +329,34 @@ async def test_generate_image_service_error(
blocking=True, blocking=True,
return_response=True, return_response=True,
) )
@pytest.mark.parametrize(
("side_effect", "error"),
[
(APIConnectionError(request=None), "Connection error"),
(
AuthenticationError(
response=Response(status_code=None, request=""), body=None, message=None
),
"Invalid API key",
),
(
BadRequestError(
response=Response(status_code=None, request=""), body=None, message=None
),
"openai_conversation integration not ready yet: None",
),
],
)
async def test_init_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, caplog, side_effect, error
) -> None:
"""Test initialization errors."""
with patch(
"openai.resources.models.AsyncModels.list",
side_effect=side_effect,
):
assert await async_setup_component(hass, "openai_conversation", {})
await hass.async_block_till_done()
assert error in caplog.text