From 13b6cfa438441f258e6de1934e045f2611a9b220 Mon Sep 17 00:00:00 2001 From: Tim Laing <11019084+timlaing@users.noreply.github.com> Date: Sat, 15 Mar 2025 02:54:49 +0000 Subject: [PATCH] Add generate content service for OpenAI to match Google AI (#122818) * Aded Generate Content Service for OpenAI to match Google AI * Fixed code for commit checks * Addressed code review comments * Address review comments * Addressed @balloob review comments. * Address futher review comments from @balloob --- .../openai_conversation/__init__.py | 145 +++++++- .../components/openai_conversation/const.py | 22 +- .../components/openai_conversation/icons.json | 3 + .../openai_conversation/manifest.json | 2 +- .../openai_conversation/services.yaml | 20 ++ .../openai_conversation/strings.json | 18 + requirements_all.txt | 2 +- requirements_test_all.txt | 2 +- .../openai_conversation/test_init.py | 314 +++++++++++++++++- 9 files changed, 500 insertions(+), 28 deletions(-) diff --git a/homeassistant/components/openai_conversation/__init__.py b/homeassistant/components/openai_conversation/__init__.py index 0fbda9b7f4a..d7fc5205f17 100644 --- a/homeassistant/components/openai_conversation/__init__.py +++ b/homeassistant/components/openai_conversation/__init__.py @@ -2,7 +2,26 @@ from __future__ import annotations +import base64 +from mimetypes import guess_file_type +from pathlib import Path + import openai +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion_content_part_image_param import ( + ChatCompletionContentPartImageParam, + ImageURL, +) +from openai.types.chat.chat_completion_content_part_param import ( + ChatCompletionContentPartParam, +) +from openai.types.chat.chat_completion_content_part_text_param import ( + ChatCompletionContentPartTextParam, +) +from openai.types.chat.chat_completion_user_message_param import ( + ChatCompletionUserMessageParam, +) +from openai.types.images_response import ImagesResponse import voluptuous as vol from homeassistant.config_entries import ConfigEntry @@ -22,15 +41,33 @@ from homeassistant.helpers import config_validation as cv, selector from homeassistant.helpers.httpx_client import get_async_client from homeassistant.helpers.typing import ConfigType -from .const import DOMAIN, LOGGER +from .const import ( + CONF_CHAT_MODEL, + CONF_FILENAMES, + CONF_PROMPT, + DOMAIN, + LOGGER, + RECOMMENDED_CHAT_MODEL, +) SERVICE_GENERATE_IMAGE = "generate_image" +SERVICE_GENERATE_CONTENT = "generate_content" + PLATFORMS = (Platform.CONVERSATION,) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient] +def encode_file(file_path: str) -> tuple[str, str]: + """Return base64 version of file contents.""" + mime_type, _ = guess_file_type(file_path) + if mime_type is None: + mime_type = "application/octet-stream" + with open(file_path, "rb") as image_file: + return (mime_type, base64.b64encode(image_file.read()).decode("utf-8")) + + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up OpenAI Conversation.""" @@ -49,9 +86,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: client: openai.AsyncClient = entry.runtime_data try: - response = await client.images.generate( + response: ImagesResponse = await client.images.generate( model="dall-e-3", - prompt=call.data["prompt"], + prompt=call.data[CONF_PROMPT], size=call.data["size"], quality=call.data["quality"], style=call.data["style"], @@ -63,6 +100,105 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return response.data[0].model_dump(exclude={"b64_json"}) + async def send_prompt(call: ServiceCall) -> ServiceResponse: + """Send a prompt to ChatGPT and return the response.""" + entry_id = call.data["config_entry"] + entry = hass.config_entries.async_get_entry(entry_id) + + if entry is None or entry.domain != DOMAIN: + raise ServiceValidationError( + translation_domain=DOMAIN, + translation_key="invalid_config_entry", + translation_placeholders={"config_entry": entry_id}, + ) + + model: str = entry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL) + client: openai.AsyncClient = entry.runtime_data + + prompt_parts: list[ChatCompletionContentPartParam] = [ + ChatCompletionContentPartTextParam( + type="text", + text=call.data[CONF_PROMPT], + ) + ] + + def append_files_to_prompt() -> None: + for filename in call.data[CONF_FILENAMES]: + if not hass.config.is_allowed_path(filename): + raise HomeAssistantError( + f"Cannot read `{filename}`, no access to path; " + "`allowlist_external_dirs` may need to be adjusted in " + "`configuration.yaml`" + ) + if not Path(filename).exists(): + raise HomeAssistantError(f"`{filename}` does not exist") + mime_type, base64_file = encode_file(filename) + if "image/" not in mime_type: + raise HomeAssistantError( + "Only images are supported by the OpenAI API," + f"`{filename}` is not an image file" + ) + prompt_parts.append( + ChatCompletionContentPartImageParam( + type="image_url", + image_url=ImageURL( + url=f"data:{mime_type};base64,{base64_file}" + ), + ) + ) + + if CONF_FILENAMES in call.data: + await hass.async_add_executor_job(append_files_to_prompt) + + messages: list[ChatCompletionUserMessageParam] = [ + ChatCompletionUserMessageParam( + role="user", + content=prompt_parts, + ) + ] + + try: + response: ChatCompletion = await client.chat.completions.create( + model=model, + messages=messages, + n=1, + response_format={ + "type": "json_object", + }, + ) + + except openai.OpenAIError as err: + raise HomeAssistantError(f"Error generating content: {err}") from err + except FileNotFoundError as err: + raise HomeAssistantError(f"Error generating content: {err}") from err + + response_text: str = "" + for response_choice in response.choices: + if response_choice.message.content is not None: + response_text += response_choice.message.content.strip() + + return {"text": response_text} + + hass.services.async_register( + DOMAIN, + SERVICE_GENERATE_CONTENT, + send_prompt, + schema=vol.Schema( + { + vol.Required("config_entry"): selector.ConfigEntrySelector( + { + "integration": DOMAIN, + } + ), + vol.Required(CONF_PROMPT): cv.string, + vol.Optional(CONF_FILENAMES, default=[]): vol.All( + cv.ensure_list, [cv.string] + ), + } + ), + supports_response=SupportsResponse.ONLY, + ) + hass.services.async_register( DOMAIN, SERVICE_GENERATE_IMAGE, @@ -74,7 +210,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: "integration": DOMAIN, } ), - vol.Required("prompt"): cv.string, + vol.Required(CONF_PROMPT): cv.string, vol.Optional("size", default="1024x1024"): vol.In( ("1024x1024", "1024x1792", "1792x1024") ), @@ -84,6 +220,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ), supports_response=SupportsResponse.ONLY, ) + return True diff --git a/homeassistant/components/openai_conversation/const.py b/homeassistant/components/openai_conversation/const.py index 793e021e332..c9987cb81b9 100644 --- a/homeassistant/components/openai_conversation/const.py +++ b/homeassistant/components/openai_conversation/const.py @@ -3,22 +3,24 @@ import logging DOMAIN = "openai_conversation" -LOGGER = logging.getLogger(__package__) +LOGGER: logging.Logger = logging.getLogger(__package__) -CONF_RECOMMENDED = "recommended" -CONF_PROMPT = "prompt" CONF_CHAT_MODEL = "chat_model" -RECOMMENDED_CHAT_MODEL = "gpt-4o-mini" +CONF_FILENAMES = "filenames" CONF_MAX_TOKENS = "max_tokens" -RECOMMENDED_MAX_TOKENS = 150 -CONF_TOP_P = "top_p" -RECOMMENDED_TOP_P = 1.0 -CONF_TEMPERATURE = "temperature" -RECOMMENDED_TEMPERATURE = 1.0 +CONF_PROMPT = "prompt" +CONF_PROMPT = "prompt" CONF_REASONING_EFFORT = "reasoning_effort" +CONF_RECOMMENDED = "recommended" +CONF_TEMPERATURE = "temperature" +CONF_TOP_P = "top_p" +RECOMMENDED_CHAT_MODEL = "gpt-4o-mini" +RECOMMENDED_MAX_TOKENS = 150 RECOMMENDED_REASONING_EFFORT = "low" +RECOMMENDED_TEMPERATURE = 1.0 +RECOMMENDED_TOP_P = 1.0 -UNSUPPORTED_MODELS = [ +UNSUPPORTED_MODELS: list[str] = [ "o1-mini", "o1-mini-2024-09-12", "o1-preview", diff --git a/homeassistant/components/openai_conversation/icons.json b/homeassistant/components/openai_conversation/icons.json index 3abecd640d1..f0ece31c304 100644 --- a/homeassistant/components/openai_conversation/icons.json +++ b/homeassistant/components/openai_conversation/icons.json @@ -2,6 +2,9 @@ "services": { "generate_image": { "service": "mdi:image-sync" + }, + "generate_content": { + "service": "mdi:receipt-text" } } } diff --git a/homeassistant/components/openai_conversation/manifest.json b/homeassistant/components/openai_conversation/manifest.json index a7aa7884dc4..cc1c56b0927 100644 --- a/homeassistant/components/openai_conversation/manifest.json +++ b/homeassistant/components/openai_conversation/manifest.json @@ -8,5 +8,5 @@ "documentation": "https://www.home-assistant.io/integrations/openai_conversation", "integration_type": "service", "iot_class": "cloud_polling", - "requirements": ["openai==1.61.0"] + "requirements": ["openai==1.65.2"] } diff --git a/homeassistant/components/openai_conversation/services.yaml b/homeassistant/components/openai_conversation/services.yaml index 3db71cae383..75fa097f25d 100644 --- a/homeassistant/components/openai_conversation/services.yaml +++ b/homeassistant/components/openai_conversation/services.yaml @@ -38,3 +38,23 @@ generate_image: options: - "vivid" - "natural" +generate_content: + fields: + config_entry: + required: true + selector: + config_entry: + integration: openai_conversation + prompt: + required: true + selector: + text: + multiline: true + example: "Hello, how can I help you?" + filenames: + selector: + text: + multiline: true + example: | + - /path/to/file1.txt + - /path/to/file2.txt diff --git a/homeassistant/components/openai_conversation/strings.json b/homeassistant/components/openai_conversation/strings.json index aba4fdc3d40..c9d7ee112bd 100644 --- a/homeassistant/components/openai_conversation/strings.json +++ b/homeassistant/components/openai_conversation/strings.json @@ -72,6 +72,24 @@ "description": "The style of the generated image" } } + }, + "generate_content": { + "name": "Generate content", + "description": "Sends a conversational query to ChatGPT including any attached image files", + "fields": { + "config_entry": { + "name": "Config entry", + "description": "The config entry to use for this action" + }, + "prompt": { + "name": "Prompt", + "description": "The prompt to send" + }, + "filenames": { + "name": "Files", + "description": "List of files to upload" + } + } } }, "exceptions": { diff --git a/requirements_all.txt b/requirements_all.txt index 250d6597718..5947a0c5ad9 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -1580,7 +1580,7 @@ open-garage==0.2.0 open-meteo==0.3.2 # homeassistant.components.openai_conversation -openai==1.61.0 +openai==1.65.2 # homeassistant.components.openerz openerz-api==0.3.0 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index c4c6463d48a..97af399a260 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -1325,7 +1325,7 @@ open-garage==0.2.0 open-meteo==0.3.2 # homeassistant.components.openai_conversation -openai==1.61.0 +openai==1.65.2 # homeassistant.components.openerz openerz-api==0.3.0 diff --git a/tests/components/openai_conversation/test_init.py b/tests/components/openai_conversation/test_init.py index d78ce398c92..05a92d0b98e 100644 --- a/tests/components/openai_conversation/test_init.py +++ b/tests/components/openai_conversation/test_init.py @@ -1,18 +1,21 @@ """Tests for the OpenAI integration.""" -from unittest.mock import patch +from unittest.mock import AsyncMock, mock_open, patch -from httpx import Response +from httpx import Request, Response from openai import ( APIConnectionError, AuthenticationError, BadRequestError, RateLimitError, ) +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage from openai.types.image import Image from openai.types.images_response import ImagesResponse import pytest +from homeassistant.components.openai_conversation import CONF_FILENAMES from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError, ServiceValidationError from homeassistant.setup import async_setup_component @@ -114,7 +117,9 @@ async def test_generate_image_service_error( patch( "openai.resources.images.AsyncImages.generate", side_effect=RateLimitError( - response=Response(status_code=None, request=""), + response=Response( + status_code=500, request=Request(method="GET", url="") + ), body=None, message="Reason", ), @@ -133,22 +138,60 @@ async def test_generate_image_service_error( ) +@pytest.mark.usefixtures("mock_init_component") +async def test_generate_content_service_with_image_not_allowed_path( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, +) -> None: + """Test generate content service with an image in a not allowed path.""" + with ( + patch("pathlib.Path.exists", return_value=True), + patch.object(hass.config, "is_allowed_path", return_value=False), + pytest.raises( + HomeAssistantError, + match=( + "Cannot read `doorbell_snapshot.jpg`, no access to path; " + "`allowlist_external_dirs` may need to be adjusted in " + "`configuration.yaml`" + ), + ), + ): + await hass.services.async_call( + "openai_conversation", + "generate_content", + { + "config_entry": mock_config_entry.entry_id, + "prompt": "Describe this image from my doorbell camera", + "filenames": "doorbell_snapshot.jpg", + }, + blocking=True, + return_response=True, + ) + + +@pytest.mark.parametrize( + ("service_name", "error"), + [ + ("generate_image", "Invalid config entry provided. Got invalid_entry"), + ("generate_content", "Invalid config entry provided. Got invalid_entry"), + ], +) async def test_invalid_config_entry( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component, + service_name: str, + error: str, ) -> None: """Assert exception when invalid config entry is provided.""" service_data = { "prompt": "Picture of a dog", "config_entry": "invalid_entry", } - with pytest.raises( - ServiceValidationError, match="Invalid config entry provided. Got invalid_entry" - ): + with pytest.raises(ServiceValidationError, match=error): await hass.services.async_call( "openai_conversation", - "generate_image", + service_name, service_data, blocking=True, return_response=True, @@ -158,18 +201,29 @@ async def test_invalid_config_entry( @pytest.mark.parametrize( ("side_effect", "error"), [ - (APIConnectionError(request=None), "Connection error"), + ( + APIConnectionError(request=Request(method="GET", url="test")), + "Connection error", + ), ( AuthenticationError( - response=Response(status_code=None, request=""), body=None, message=None + response=Response( + status_code=500, request=Request(method="GET", url="test") + ), + body=None, + message="", ), "Invalid API key", ), ( BadRequestError( - response=Response(status_code=None, request=""), body=None, message=None + response=Response( + status_code=500, request=Request(method="GET", url="test") + ), + body=None, + message="", ), - "openai_conversation integration not ready yet: None", + "openai_conversation integration not ready yet", ), ], ) @@ -188,3 +242,241 @@ async def test_init_error( assert await async_setup_component(hass, "openai_conversation", {}) await hass.async_block_till_done() assert error in caplog.text + + +@pytest.mark.parametrize( + ("service_data", "expected_args", "number_of_files"), + [ + ( + {"prompt": "Picture of a dog", "filenames": []}, + { + "messages": [ + { + "content": [ + { + "type": "text", + "text": "Picture of a dog", + }, + ], + }, + ], + }, + 0, + ), + ( + {"prompt": "Picture of a dog", "filenames": ["/a/b/c.jpg"]}, + { + "messages": [ + { + "content": [ + { + "type": "text", + "text": "Picture of a dog", + }, + { + "type": "image_url", + "image_url": { + "url": "", + }, + }, + ], + }, + ], + }, + 1, + ), + ( + { + "prompt": "Picture of a dog", + "filenames": ["/a/b/c.jpg", "d/e/f.jpg"], + }, + { + "messages": [ + { + "content": [ + { + "type": "text", + "text": "Picture of a dog", + }, + { + "type": "image_url", + "image_url": { + "url": "", + }, + }, + { + "type": "image_url", + "image_url": { + "url": "", + }, + }, + ], + }, + ], + }, + 2, + ), + ], +) +async def test_generate_content_service( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + service_data, + expected_args, + number_of_files, +) -> None: + """Test generate content service.""" + service_data["config_entry"] = mock_config_entry.entry_id + expected_args["model"] = "gpt-4o-mini" + expected_args["n"] = 1 + expected_args["response_format"] = {"type": "json_object"} + expected_args["messages"][0]["role"] = "user" + + with ( + patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + ) as mock_create, + patch( + "base64.b64encode", side_effect=[b"BASE64IMAGE1", b"BASE64IMAGE2"] + ) as mock_b64encode, + patch("builtins.open", mock_open(read_data="ABC")) as mock_file, + patch("pathlib.Path.exists", return_value=True), + patch.object(hass.config, "is_allowed_path", return_value=True), + ): + mock_create.return_value = ChatCompletion( + id="", + model="", + created=1700000000, + object="chat.completion", + choices=[ + Choice( + index=0, + finish_reason="stop", + message=ChatCompletionMessage( + role="assistant", + content="This is the response", + ), + ) + ], + ) + + response = await hass.services.async_call( + "openai_conversation", + "generate_content", + service_data, + blocking=True, + return_response=True, + ) + assert response == {"text": "This is the response"} + assert len(mock_create.mock_calls) == 1 + assert mock_create.mock_calls[0][2] == expected_args + assert mock_b64encode.call_count == number_of_files + for idx, file in enumerate(service_data[CONF_FILENAMES]): + assert mock_file.call_args_list[idx][0][0] == file + + +@pytest.mark.parametrize( + ( + "service_data", + "error", + "number_of_files", + "exists_side_effect", + "is_allowed_side_effect", + ), + [ + ( + {"prompt": "Picture of a dog", "filenames": ["/a/b/c.jpg"]}, + "`/a/b/c.jpg` does not exist", + 0, + [False], + [True], + ), + ( + { + "prompt": "Picture of a dog", + "filenames": ["/a/b/c.jpg", "d/e/f.png"], + }, + "Cannot read `d/e/f.png`, no access to path; `allowlist_external_dirs` may need to be adjusted in `configuration.yaml`", + 1, + [True, True], + [True, False], + ), + ( + {"prompt": "Not a picture of a dog", "filenames": ["/a/b/c.pdf"]}, + "Only images are supported by the OpenAI API,`/a/b/c.pdf` is not an image file", + 1, + [True], + [True], + ), + ], +) +async def test_generate_content_service_invalid( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + service_data, + error, + number_of_files, + exists_side_effect, + is_allowed_side_effect, +) -> None: + """Test generate content service.""" + service_data["config_entry"] = mock_config_entry.entry_id + + with ( + patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + ) as mock_create, + patch( + "base64.b64encode", side_effect=[b"BASE64IMAGE1", b"BASE64IMAGE2"] + ) as mock_b64encode, + patch("builtins.open", mock_open(read_data="ABC")), + patch("pathlib.Path.exists", side_effect=exists_side_effect), + patch.object( + hass.config, "is_allowed_path", side_effect=is_allowed_side_effect + ), + ): + with pytest.raises(HomeAssistantError, match=error): + await hass.services.async_call( + "openai_conversation", + "generate_content", + service_data, + blocking=True, + return_response=True, + ) + assert len(mock_create.mock_calls) == 0 + assert mock_b64encode.call_count == number_of_files + + +@pytest.mark.usefixtures("mock_init_component") +async def test_generate_content_service_error( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, +) -> None: + """Test generate content service handles errors.""" + with ( + patch( + "openai.resources.chat.completions.AsyncCompletions.create", + side_effect=RateLimitError( + response=Response( + status_code=417, request=Request(method="GET", url="") + ), + body=None, + message="Reason", + ), + ), + pytest.raises(HomeAssistantError, match="Error generating content: Reason"), + ): + await hass.services.async_call( + "openai_conversation", + "generate_content", + { + "config_entry": mock_config_entry.entry_id, + "prompt": "Image of an epic fail", + }, + blocking=True, + return_response=True, + )