diff --git a/homeassistant/components/openai_conversation/__init__.py b/homeassistant/components/openai_conversation/__init__.py index 2a91f1b1b38..0ba7b53795b 100644 --- a/homeassistant/components/openai_conversation/__init__.py +++ b/homeassistant/components/openai_conversation/__init__.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Literal, cast + import openai import voluptuous as vol @@ -13,7 +15,11 @@ from homeassistant.core import ( ServiceResponse, SupportsResponse, ) -from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError +from homeassistant.exceptions import ( + ConfigEntryNotReady, + HomeAssistantError, + ServiceValidationError, +) from homeassistant.helpers import ( config_validation as cv, issue_registry as ir, @@ -27,13 +33,25 @@ SERVICE_GENERATE_IMAGE = "generate_image" PLATFORMS = (Platform.CONVERSATION,) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) +type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient] + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up OpenAI Conversation.""" async def render_image(call: ServiceCall) -> ServiceResponse: """Render an image with dall-e.""" - client = hass.data[DOMAIN][call.data["config_entry"]] + entry_id = call.data["config_entry"] + entry = hass.config_entries.async_get_entry(entry_id) + + if entry is None or entry.domain != DOMAIN: + raise ServiceValidationError( + translation_domain=DOMAIN, + translation_key="invalid_config_entry", + translation_placeholders={"config_entry": entry_id}, + ) + + client: openai.AsyncClient = entry.runtime_data if call.data["size"] in ("256", "512", "1024"): ir.async_create_issue( @@ -51,6 +69,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: else: size = call.data["size"] + size = cast( + Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], + size, + ) # size is selector, so no need to check further + try: response = await client.images.generate( model="dall-e-3", @@ -90,7 +113,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_setup_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> bool: """Set up OpenAI Conversation from a config entry.""" client = openai.AsyncOpenAI(api_key=entry.data[CONF_API_KEY]) try: @@ -101,7 +124,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: except openai.OpenAIError as err: raise ConfigEntryNotReady(err) from err - hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client + entry.runtime_data = client await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) @@ -110,8 +133,4 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload OpenAI.""" - if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS): - return False - - hass.data[DOMAIN].pop(entry.entry_id) - return True + return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index 8de146e0851..29228ba8e3b 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -22,7 +22,6 @@ from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation from homeassistant.components.conversation import trace -from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError, TemplateError @@ -30,6 +29,7 @@ from homeassistant.helpers import device_registry as dr, intent, llm, template from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util import ulid +from . import OpenAIConfigEntry from .const import ( CONF_CHAT_MODEL, CONF_MAX_TOKENS, @@ -50,7 +50,7 @@ MAX_TOOL_ITERATIONS = 10 async def async_setup_entry( hass: HomeAssistant, - config_entry: ConfigEntry, + config_entry: OpenAIConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up conversation entities.""" @@ -74,7 +74,7 @@ class OpenAIConversationEntity( _attr_has_entity_name = True _attr_name = None - def __init__(self, entry: ConfigEntry) -> None: + def __init__(self, entry: OpenAIConfigEntry) -> None: """Initialize the agent.""" self.entry = entry self.history: dict[str, list[ChatCompletionMessageParam]] = {} @@ -203,7 +203,7 @@ class OpenAIConversationEntity( trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages} ) - client: openai.AsyncClient = self.hass.data[DOMAIN][self.entry.entry_id] + client = self.entry.runtime_data # To prevent infinite loops, we limit the number of iterations for _iteration in range(MAX_TOOL_ITERATIONS): diff --git a/homeassistant/components/openai_conversation/strings.json b/homeassistant/components/openai_conversation/strings.json index 1e93c60b6a9..c5d42eb9521 100644 --- a/homeassistant/components/openai_conversation/strings.json +++ b/homeassistant/components/openai_conversation/strings.json @@ -60,6 +60,11 @@ } } }, + "exceptions": { + "invalid_config_entry": { + "message": "Invalid config entry provided. Got {config_entry}" + } + }, "issues": { "image_size_deprecated_format": { "title": "Deprecated size format for image generation service", diff --git a/tests/components/openai_conversation/test_init.py b/tests/components/openai_conversation/test_init.py index f03013556c7..c9431aa1083 100644 --- a/tests/components/openai_conversation/test_init.py +++ b/tests/components/openai_conversation/test_init.py @@ -14,7 +14,7 @@ from openai.types.images_response import ImagesResponse import pytest from homeassistant.core import HomeAssistant -from homeassistant.exceptions import HomeAssistantError +from homeassistant.exceptions import HomeAssistantError, ServiceValidationError from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry @@ -160,6 +160,28 @@ async def test_generate_image_service_error( ) +async def test_invalid_config_entry( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, +) -> None: + """Assert exception when invalid config entry is provided.""" + service_data = { + "prompt": "Picture of a dog", + "config_entry": "invalid_entry", + } + with pytest.raises( + ServiceValidationError, match="Invalid config entry provided. Got invalid_entry" + ): + await hass.services.async_call( + "openai_conversation", + "generate_image", + service_data, + blocking=True, + return_response=True, + ) + + @pytest.mark.parametrize( ("side_effect", "error"), [