Files
core/homeassistant/components/openai_conversation/ai_task.py
2025-10-12 12:44:38 -04:00

156 lines
5.2 KiB
Python

"""AI Task integration for OpenAI."""
from __future__ import annotations
import base64
from json import JSONDecodeError
import logging
from typing import TYPE_CHECKING
from openai.types.responses.response_output_item import ImageGenerationCall
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
from .const import (
CONF_CHAT_MODEL,
CONF_IMAGE_MODEL,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_IMAGE_MODEL,
UNSUPPORTED_IMAGE_MODELS,
)
from .entity import OpenAIBaseLLMEntity
if TYPE_CHECKING:
from homeassistant.config_entries import ConfigSubentry
from . import OpenAIConfigEntry
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Register one AI Task entity for every ``ai_task_data`` subentry.

    Subentries of any other type are ignored. Each entity is added in its
    own ``async_add_entities`` call so that it is associated with the id of
    the subentry it was created from.
    """
    # Lazily filter so entities are still created and registered one by one,
    # in the iteration order of ``config_entry.subentries``.
    task_subentries = (
        candidate
        for candidate in config_entry.subentries.values()
        if candidate.subentry_type == "ai_task_data"
    )
    for task_subentry in task_subentries:
        entity = OpenAITaskEntity(config_entry, task_subentry)
        async_add_entities(
            [entity],
            config_subentry_id=task_subentry.subentry_id,
        )
class OpenAITaskEntity(
    ai_task.AITaskEntity,
    OpenAIBaseLLMEntity,
):
    """OpenAI AI Task entity.

    Handles "generate data" tasks (free text or structured JSON) and
    "generate image" tasks by running the task's chat log through
    ``_async_handle_chat_log`` and converting the last assistant output into
    the matching ``ai_task`` result object.
    """

    def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the entity and declare the features it supports."""
        super().__init__(entry, subentry)
        # Data generation and attachments are always advertised.
        self._attr_supported_features = (
            ai_task.AITaskEntityFeature.GENERATE_DATA
            | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
        )
        # Image generation is only advertised when the configured chat model
        # does not start with one of the UNSUPPORTED_IMAGE_MODELS prefixes.
        model = self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
        if not model.startswith(tuple(UNSUPPORTED_IMAGE_MODELS)):
            self._attr_supported_features |= ai_task.AITaskEntityFeature.GENERATE_IMAGE

    async def _async_generate_data(
        self,
        task: ai_task.GenDataTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenDataTaskResult:
        """Handle a generate data task.

        Returns the assistant's text as-is when the task has no structure,
        otherwise the text parsed as JSON.

        Raises:
            HomeAssistantError: If the last chat log entry is not assistant
                content, or if a structured response is not valid JSON.
        """
        await self._async_handle_chat_log(chat_log, task.name, task.structure)

        last_content = chat_log.content[-1]
        if not isinstance(last_content, conversation.AssistantContent):
            raise HomeAssistantError(
                "Last content in chat log is not an AssistantContent"
            )

        text = last_content.content or ""

        if not task.structure:
            return ai_task.GenDataTaskResult(
                conversation_id=chat_log.conversation_id,
                data=text,
            )

        try:
            data = json_loads(text)
        except JSONDecodeError as err:
            # Log the raw response so a malformed structured reply can be
            # diagnosed; the raised error itself stays generic.
            _LOGGER.error(
                "Failed to parse JSON response: %s. Response: %s",
                err,
                text,
            )
            raise HomeAssistantError("Error with OpenAI structured response") from err

        return ai_task.GenDataTaskResult(
            conversation_id=chat_log.conversation_id,
            data=data,
        )

    @staticmethod
    def _parse_image_size(size: object) -> tuple[int | None, int | None]:
        """Parse a ``"<width>x<height>"`` string into integer dimensions.

        Returns ``(None, None)`` when ``size`` is missing, is not a string,
        or does not have the expected shape (for example ``"auto"``). The
        dimensions are optional metadata, so an unexpected value must not
        fail an otherwise successful image generation.
        """
        if not isinstance(size, str):
            return None, None
        width, separator, height = size.partition("x")
        if not separator:
            return None, None
        try:
            return (
                int(width) if width else None,
                int(height) if height else None,
            )
        except ValueError:
            return None, None

    async def _async_generate_image(
        self,
        task: ai_task.GenImageTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenImageTaskResult:
        """Handle a generate image task.

        Picks the most recent image generation call that carries image data
        from the trailing assistant messages, decodes it from base64 and
        returns it together with whatever metadata the call exposes.

        Raises:
            HomeAssistantError: If the last chat log entry is not assistant
                content, or if no image data was returned.
        """
        # NOTE(review): force_image presumably makes the request require an
        # image generation call; confirm against _async_handle_chat_log.
        await self._async_handle_chat_log(chat_log, task.name, force_image=True)

        if not isinstance(chat_log.content[-1], conversation.AssistantContent):
            raise HomeAssistantError(
                "Last content in chat log is not an AssistantContent"
            )

        image_call: ImageGenerationCall | None = None
        # Walk the trailing run of assistant messages newest-first and stop
        # at the first entry that is not assistant content.
        for content in reversed(chat_log.content):
            if not isinstance(content, conversation.AssistantContent):
                break
            if not isinstance(content.native, ImageGenerationCall):
                continue
            if image_call is None or image_call.result is None:
                # Keep looking until a call that actually has a result is found.
                image_call = content.native
            else:
                # Remove image data from chat log to save memory
                content.native.result = None

        if image_call is None or image_call.result is None:
            raise HomeAssistantError("No image returned")

        image_data = base64.b64decode(image_call.result)
        # The decoded bytes are returned below; drop the base64 copy that the
        # chat log would otherwise keep alive.
        image_call.result = None

        # These attributes are read defensively because they may be absent on
        # the call object; missing or empty values fall back to defaults.
        output_format = getattr(image_call, "output_format", None)
        mime_type = f"image/{output_format}" if output_format else "image/png"

        width, height = self._parse_image_size(getattr(image_call, "size", None))

        return ai_task.GenImageTaskResult(
            image_data=image_data,
            conversation_id=chat_log.conversation_id,
            mime_type=mime_type,
            width=width,
            height=height,
            model=self.subentry.data.get(CONF_IMAGE_MODEL, RECOMMENDED_IMAGE_MODEL),
            revised_prompt=getattr(image_call, "revised_prompt", None),
        )