mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Migrate openai_conversation to entry.runtime_data
(#118535)
* switch to entry.runtime_data * check for missing config entry * Update homeassistant/components/openai_conversation/__init__.py --------- Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
a59c890779
commit
4998fe5e6d
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Literal, cast
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -13,7 +15,11 @@ from homeassistant.core import (
|
|||||||
ServiceResponse,
|
ServiceResponse,
|
||||||
SupportsResponse,
|
SupportsResponse,
|
||||||
)
|
)
|
||||||
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
|
from homeassistant.exceptions import (
|
||||||
|
ConfigEntryNotReady,
|
||||||
|
HomeAssistantError,
|
||||||
|
ServiceValidationError,
|
||||||
|
)
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
config_validation as cv,
|
config_validation as cv,
|
||||||
issue_registry as ir,
|
issue_registry as ir,
|
||||||
@ -27,13 +33,25 @@ SERVICE_GENERATE_IMAGE = "generate_image"
|
|||||||
PLATFORMS = (Platform.CONVERSATION,)
|
PLATFORMS = (Platform.CONVERSATION,)
|
||||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||||
|
|
||||||
|
type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up OpenAI Conversation."""
|
"""Set up OpenAI Conversation."""
|
||||||
|
|
||||||
async def render_image(call: ServiceCall) -> ServiceResponse:
|
async def render_image(call: ServiceCall) -> ServiceResponse:
|
||||||
"""Render an image with dall-e."""
|
"""Render an image with dall-e."""
|
||||||
client = hass.data[DOMAIN][call.data["config_entry"]]
|
entry_id = call.data["config_entry"]
|
||||||
|
entry = hass.config_entries.async_get_entry(entry_id)
|
||||||
|
|
||||||
|
if entry is None or entry.domain != DOMAIN:
|
||||||
|
raise ServiceValidationError(
|
||||||
|
translation_domain=DOMAIN,
|
||||||
|
translation_key="invalid_config_entry",
|
||||||
|
translation_placeholders={"config_entry": entry_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
client: openai.AsyncClient = entry.runtime_data
|
||||||
|
|
||||||
if call.data["size"] in ("256", "512", "1024"):
|
if call.data["size"] in ("256", "512", "1024"):
|
||||||
ir.async_create_issue(
|
ir.async_create_issue(
|
||||||
@ -51,6 +69,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
else:
|
else:
|
||||||
size = call.data["size"]
|
size = call.data["size"]
|
||||||
|
|
||||||
|
size = cast(
|
||||||
|
Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
|
||||||
|
size,
|
||||||
|
) # size is selector, so no need to check further
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.images.generate(
|
response = await client.images.generate(
|
||||||
model="dall-e-3",
|
model="dall-e-3",
|
||||||
@ -90,7 +113,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_setup_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> bool:
|
||||||
"""Set up OpenAI Conversation from a config entry."""
|
"""Set up OpenAI Conversation from a config entry."""
|
||||||
client = openai.AsyncOpenAI(api_key=entry.data[CONF_API_KEY])
|
client = openai.AsyncOpenAI(api_key=entry.data[CONF_API_KEY])
|
||||||
try:
|
try:
|
||||||
@ -101,7 +124,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
except openai.OpenAIError as err:
|
except openai.OpenAIError as err:
|
||||||
raise ConfigEntryNotReady(err) from err
|
raise ConfigEntryNotReady(err) from err
|
||||||
|
|
||||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
|
entry.runtime_data = client
|
||||||
|
|
||||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||||
|
|
||||||
@ -110,8 +133,4 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
|
|
||||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Unload OpenAI."""
|
"""Unload OpenAI."""
|
||||||
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
|
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||||
return False
|
|
||||||
|
|
||||||
hass.data[DOMAIN].pop(entry.entry_id)
|
|
||||||
return True
|
|
||||||
|
@ -22,7 +22,6 @@ from voluptuous_openapi import convert
|
|||||||
|
|
||||||
from homeassistant.components import assist_pipeline, conversation
|
from homeassistant.components import assist_pipeline, conversation
|
||||||
from homeassistant.components.conversation import trace
|
from homeassistant.components.conversation import trace
|
||||||
from homeassistant.config_entries import ConfigEntry
|
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||||
@ -30,6 +29,7 @@ from homeassistant.helpers import device_registry as dr, intent, llm, template
|
|||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.util import ulid
|
from homeassistant.util import ulid
|
||||||
|
|
||||||
|
from . import OpenAIConfigEntry
|
||||||
from .const import (
|
from .const import (
|
||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
CONF_MAX_TOKENS,
|
CONF_MAX_TOKENS,
|
||||||
@ -50,7 +50,7 @@ MAX_TOOL_ITERATIONS = 10
|
|||||||
|
|
||||||
async def async_setup_entry(
|
async def async_setup_entry(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
config_entry: ConfigEntry,
|
config_entry: OpenAIConfigEntry,
|
||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up conversation entities."""
|
"""Set up conversation entities."""
|
||||||
@ -74,7 +74,7 @@ class OpenAIConversationEntity(
|
|||||||
_attr_has_entity_name = True
|
_attr_has_entity_name = True
|
||||||
_attr_name = None
|
_attr_name = None
|
||||||
|
|
||||||
def __init__(self, entry: ConfigEntry) -> None:
|
def __init__(self, entry: OpenAIConfigEntry) -> None:
|
||||||
"""Initialize the agent."""
|
"""Initialize the agent."""
|
||||||
self.entry = entry
|
self.entry = entry
|
||||||
self.history: dict[str, list[ChatCompletionMessageParam]] = {}
|
self.history: dict[str, list[ChatCompletionMessageParam]] = {}
|
||||||
@ -203,7 +203,7 @@ class OpenAIConversationEntity(
|
|||||||
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
|
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
|
||||||
)
|
)
|
||||||
|
|
||||||
client: openai.AsyncClient = self.hass.data[DOMAIN][self.entry.entry_id]
|
client = self.entry.runtime_data
|
||||||
|
|
||||||
# To prevent infinite loops, we limit the number of iterations
|
# To prevent infinite loops, we limit the number of iterations
|
||||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||||
|
@ -60,6 +60,11 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"exceptions": {
|
||||||
|
"invalid_config_entry": {
|
||||||
|
"message": "Invalid config entry provided. Got {config_entry}"
|
||||||
|
}
|
||||||
|
},
|
||||||
"issues": {
|
"issues": {
|
||||||
"image_size_deprecated_format": {
|
"image_size_deprecated_format": {
|
||||||
"title": "Deprecated size format for image generation service",
|
"title": "Deprecated size format for image generation service",
|
||||||
|
@ -14,7 +14,7 @@ from openai.types.images_response import ImagesResponse
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
@ -160,6 +160,28 @@ async def test_generate_image_service_error(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_invalid_config_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
) -> None:
|
||||||
|
"""Assert exception when invalid config entry is provided."""
|
||||||
|
service_data = {
|
||||||
|
"prompt": "Picture of a dog",
|
||||||
|
"config_entry": "invalid_entry",
|
||||||
|
}
|
||||||
|
with pytest.raises(
|
||||||
|
ServiceValidationError, match="Invalid config entry provided. Got invalid_entry"
|
||||||
|
):
|
||||||
|
await hass.services.async_call(
|
||||||
|
"openai_conversation",
|
||||||
|
"generate_image",
|
||||||
|
service_data,
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("side_effect", "error"),
|
("side_effect", "error"),
|
||||||
[
|
[
|
||||||
|
Loading…
x
Reference in New Issue
Block a user