# Mirror of https://github.com/home-assistant/core.git
# Synced 2025-10-15 22:59:58 +00:00
"""AI Task integration for OpenAI."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
from json import JSONDecodeError
|
|
import logging
|
|
from typing import TYPE_CHECKING
|
|
|
|
from openai.types.responses.response_output_item import ImageGenerationCall
|
|
|
|
from homeassistant.components import ai_task, conversation
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
|
from homeassistant.util.json import json_loads
|
|
|
|
from .const import (
|
|
CONF_CHAT_MODEL,
|
|
CONF_IMAGE_MODEL,
|
|
RECOMMENDED_CHAT_MODEL,
|
|
RECOMMENDED_IMAGE_MODEL,
|
|
UNSUPPORTED_IMAGE_MODELS,
|
|
)
|
|
from .entity import OpenAIBaseLLMEntity
|
|
|
|
if TYPE_CHECKING:
|
|
from homeassistant.config_entries import ConfigSubentry
|
|
|
|
from . import OpenAIConfigEntry
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up AI Task entities.

    Creates one OpenAITaskEntity per ``ai_task_data`` subentry of the
    config entry, registering each under its own subentry id.
    """
    data_subentries = (
        sub
        for sub in config_entry.subentries.values()
        if sub.subentry_type == "ai_task_data"
    )
    for sub in data_subentries:
        async_add_entities(
            [OpenAITaskEntity(config_entry, sub)],
            config_subentry_id=sub.subentry_id,
        )
|
|
|
|
|
|
class OpenAITaskEntity(
    ai_task.AITaskEntity,
    OpenAIBaseLLMEntity,
):
    """OpenAI AI Task entity."""

    def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the entity."""
        super().__init__(entry, subentry)
        self._attr_supported_features = (
            ai_task.AITaskEntityFeature.GENERATE_DATA
            | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
        )
        # Advertise image generation only when the configured chat model does
        # not match (by prefix) any entry in the known-unsupported list.
        model = self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
        if not model.startswith(tuple(UNSUPPORTED_IMAGE_MODELS)):
            self._attr_supported_features |= ai_task.AITaskEntityFeature.GENERATE_IMAGE

    async def _async_generate_data(
        self,
        task: ai_task.GenDataTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenDataTaskResult:
        """Handle a generate data task.

        Runs the chat log through the LLM, then returns the raw text for an
        unstructured task or the parsed JSON payload for a structured one.

        Raises HomeAssistantError if the model did not produce an assistant
        reply, or if a structured response is not valid JSON.
        """
        await self._async_handle_chat_log(chat_log, task.name, task.structure)

        # The handler appends the model reply to the chat log; anything other
        # than an assistant message as the last entry means generation failed.
        if not isinstance(chat_log.content[-1], conversation.AssistantContent):
            raise HomeAssistantError(
                "Last content in chat log is not an AssistantContent"
            )

        text = chat_log.content[-1].content or ""

        if not task.structure:
            # No schema requested: hand back the plain text response.
            return ai_task.GenDataTaskResult(
                conversation_id=chat_log.conversation_id,
                data=text,
            )
        try:
            data = json_loads(text)
        except JSONDecodeError as err:
            # Log the full response to aid debugging, but surface a generic
            # error to the caller.
            _LOGGER.error(
                "Failed to parse JSON response: %s. Response: %s",
                err,
                text,
            )
            raise HomeAssistantError("Error with OpenAI structured response") from err

        return ai_task.GenDataTaskResult(
            conversation_id=chat_log.conversation_id,
            data=data,
        )

    async def _async_generate_image(
        self,
        task: ai_task.GenImageTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenImageTaskResult:
        """Handle a generate image task.

        Runs the chat log through the LLM in image mode, extracts the
        base64-encoded image from the trailing ImageGenerationCall, and
        returns it (decoded) together with whatever metadata the SDK object
        exposes (mime type, dimensions, revised prompt).

        Raises HomeAssistantError if no assistant reply or no image payload
        was produced.
        """
        await self._async_handle_chat_log(chat_log, task.name, force_image=True)

        if not isinstance(chat_log.content[-1], conversation.AssistantContent):
            raise HomeAssistantError(
                "Last content in chat log is not an AssistantContent"
            )

        # Walk the trailing run of assistant messages, newest first, looking
        # for an ImageGenerationCall that actually carries a result. The
        # first such call found is kept; once one with a result is held, the
        # payload of every other call encountered is dropped.
        image_call: ImageGenerationCall | None = None
        for content in reversed(chat_log.content):
            if not isinstance(content, conversation.AssistantContent):
                break
            if isinstance(content.native, ImageGenerationCall):
                if image_call is None or image_call.result is None:
                    image_call = content.native
                else:  # Remove image data from chat log to save memory
                    content.native.result = None

        if image_call is None or image_call.result is None:
            raise HomeAssistantError("No image returned")

        # Decode the base64 payload, then clear it from the retained chat log
        # entry so the (large) image data is not kept in memory.
        image_data = base64.b64decode(image_call.result)
        image_call.result = None

        # output_format / size / revised_prompt may be absent on the SDK
        # object; the hasattr guards keep this tolerant of that — presumably
        # SDK version differences (TODO confirm).
        if hasattr(image_call, "output_format") and (
            output_format := image_call.output_format
        ):
            mime_type = f"image/{output_format}"
        else:
            mime_type = "image/png"

        # size, when present, is a "<width>x<height>" string.
        if hasattr(image_call, "size") and (size := image_call.size):
            width, height = tuple(size.split("x"))
        else:
            width, height = None, None

        return ai_task.GenImageTaskResult(
            image_data=image_data,
            conversation_id=chat_log.conversation_id,
            mime_type=mime_type,
            width=int(width) if width else None,
            height=int(height) if height else None,
            model=self.subentry.data.get(CONF_IMAGE_MODEL, RECOMMENDED_IMAGE_MODEL),
            revised_prompt=image_call.revised_prompt
            if hasattr(image_call, "revised_prompt")
            else None,
        )
|