Files
core/homeassistant/components/ai_task/task.py
2025-09-17 10:35:55 +02:00

350 lines
10 KiB
Python

"""AI tasks to be handled by agents."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta
import io
import mimetypes
from pathlib import Path
import tempfile
from typing import Any
import voluptuous as vol
from homeassistant.components import camera, conversation, image, media_source
from homeassistant.components.http.auth import async_sign_path
from homeassistant.core import HomeAssistant, ServiceResponse, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session
from homeassistant.util import RE_SANITIZE_FILENAME, slugify
from .const import (
DATA_COMPONENT,
DATA_MEDIA_SOURCE,
DATA_PREFERENCES,
DOMAIN,
IMAGE_DIR,
IMAGE_EXPIRY_TIME,
AITaskEntityFeature,
)
def _save_camera_snapshot(image_data: camera.Image | image.Image) -> Path:
    """Write snapshot bytes to a named temporary file and return its path.

    The file suffix is guessed from the snapshot's content type using the
    non-strict MIME table. The file is created with ``delete=False``, so it
    stays on disk after this call; removing it is left to the caller.
    """
    extension = mimetypes.guess_extension(image_data.content_type, False)
    snapshot_file = tempfile.NamedTemporaryFile(
        mode="wb",
        suffix=extension,
        delete=False,
    )
    # Leaving the context closes (and flushes) the handle but keeps the file.
    with snapshot_file:
        snapshot_file.write(image_data.content)
    return Path(snapshot_file.name)
async def _resolve_attachments(
    hass: HomeAssistant,
    session: ChatSession,
    attachments: list[dict] | None = None,
) -> list[conversation.Attachment]:
    """Resolve attachments for a task.

    Each attachment dict must contain a ``media_content_id``. IDs that start
    with the camera or image media-source prefix are resolved by fetching the
    entity's image and writing it to a temporary file. Every other ID is
    resolved through ``media_source`` and must map to a local path.

    Temporary files are removed when ``session`` is cleaned up. If resolving
    fails part-way through, the temporary files created so far are scheduled
    for removal right away, because the session cleanup hook has not been
    registered at that point.

    Returns:
        One ``conversation.Attachment`` per input attachment, in input order.
        An empty list when ``attachments`` is ``None`` or empty.

    Raises:
        HomeAssistantError: If a regular media source does not resolve to a
            local path.
    """
    resolved_attachments: list[conversation.Attachment] = []
    created_files: list[Path] = []

    def cleanup_files() -> None:
        """Cleanup temporary files."""
        for file in created_files:
            file.unlink(missing_ok=True)

    try:
        for attachment in attachments or []:
            media_content_id = attachment["media_content_id"]

            # Special case for certain media sources
            for integration in camera, image:
                media_source_prefix = f"media-source://{integration.DOMAIN}/"
                if not media_content_id.startswith(media_source_prefix):
                    continue

                # Extract entity_id from the media content ID
                entity_id = media_content_id.removeprefix(media_source_prefix)

                # Get snapshot from entity
                image_data = await integration.async_get_image(hass, entity_id)

                temp_filename = await hass.async_add_executor_job(
                    _save_camera_snapshot, image_data
                )
                created_files.append(temp_filename)

                resolved_attachments.append(
                    conversation.Attachment(
                        media_content_id=media_content_id,
                        mime_type=image_data.content_type,
                        path=temp_filename,
                    )
                )
                break
            else:
                # Handle regular media sources
                media = await media_source.async_resolve_media(
                    hass, media_content_id, None
                )
                if media.path is None:
                    raise HomeAssistantError(
                        "Only local attachments are currently supported"
                    )
                resolved_attachments.append(
                    conversation.Attachment(
                        media_content_id=media_content_id,
                        mime_type=media.mime_type,
                        path=media.path,
                    )
                )
    except BaseException:
        # The session has no cleanup hook yet, so snapshots written before
        # the failure would otherwise stay on disk. BaseException is caught
        # so cancellation is covered too; the error is always re-raised.
        if created_files:
            hass.async_add_executor_job(cleanup_files)
        raise

    if not created_files:
        return resolved_attachments

    @callback
    def cleanup_files_callback() -> None:
        """Cleanup temporary files."""
        # Unlinking is blocking file I/O, so it is pushed to the executor.
        hass.async_add_executor_job(cleanup_files)

    session.async_on_cleanup(cleanup_files_callback)

    return resolved_attachments
async def async_generate_data(
    hass: HomeAssistant,
    *,
    task_name: str,
    entity_id: str | None = None,
    instructions: str,
    structure: vol.Schema | None = None,
    attachments: list[dict] | None = None,
    llm_api: llm.API | None = None,
) -> GenDataTaskResult:
    """Run a data generation task in the AI Task integration.

    When ``entity_id`` is omitted, the preferred data-generation entity from
    the stored preferences is used. The entity must support data generation,
    and must support attachments if any are passed.

    Raises:
        HomeAssistantError: If no entity can be determined, the entity does
            not exist, or it lacks a required feature.
    """
    if entity_id is None:
        # Fall back to the preferred entity from the preferences store.
        entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id
        if entity_id is None:
            raise HomeAssistantError(
                "No entity_id provided and no preferred entity set"
            )

    if (entity := hass.data[DATA_COMPONENT].get_entity(entity_id)) is None:
        raise HomeAssistantError(f"AI Task entity {entity_id} not found")

    if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features:
        raise HomeAssistantError(
            f"AI Task entity {entity_id} does not support generating data"
        )

    if attachments:
        if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
            raise HomeAssistantError(
                f"AI Task entity {entity_id} does not support attachments"
            )

    with async_get_chat_session(hass) as session:
        attachment_list = await _resolve_attachments(hass, session, attachments)
        task = GenDataTask(
            name=task_name,
            instructions=instructions,
            structure=structure,
            # An empty list is passed on as None.
            attachments=attachment_list or None,
            llm_api=llm_api,
        )
        return await entity.internal_async_generate_data(session, task)
async def async_generate_image(
    hass: HomeAssistant,
    *,
    task_name: str,
    entity_id: str | None = None,
    instructions: str,
    attachments: list[dict] | None = None,
) -> ServiceResponse:
    """Run an image generation task in the AI Task integration.

    When ``entity_id`` is omitted, the preferred image-generation entity from
    the stored preferences is used. The generated image is uploaded to the
    integration's media source folder. The returned dict is the task result
    without the raw image bytes, plus ``media_source_id`` and a signed ``url``.

    Raises:
        HomeAssistantError: If no entity can be determined, the entity does
            not exist, or it lacks a required feature.
    """
    if entity_id is None:
        # Fall back to the preferred entity from the preferences store.
        entity_id = hass.data[DATA_PREFERENCES].gen_image_entity_id
        if entity_id is None:
            raise HomeAssistantError(
                "No entity_id provided and no preferred entity set"
            )

    if (entity := hass.data[DATA_COMPONENT].get_entity(entity_id)) is None:
        raise HomeAssistantError(f"AI Task entity {entity_id} not found")

    if AITaskEntityFeature.GENERATE_IMAGE not in entity.supported_features:
        raise HomeAssistantError(
            f"AI Task entity {entity_id} does not support generating images"
        )

    if attachments:
        if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
            raise HomeAssistantError(
                f"AI Task entity {entity_id} does not support attachments"
            )

    with async_get_chat_session(hass) as session:
        attachment_list = await _resolve_attachments(hass, session, attachments)
        task = GenImageTask(
            name=task_name,
            instructions=instructions,
            # An empty list is passed on as None.
            attachments=attachment_list or None,
        )
        task_result = await entity.internal_async_generate_image(session, task)

    service_result = task_result.as_dict()
    # The raw bytes are uploaded below and are not part of the response.
    image_bytes = service_result.pop("image_data")
    if service_result.get("revised_prompt") is None:
        service_result["revised_prompt"] = instructions

    source = hass.data[DATA_MEDIA_SOURCE]

    # Build a file name of the form <timestamp>_<sanitized task name><ext>.
    created_at = datetime.now()
    extension = mimetypes.guess_extension(task_result.mime_type, False) or ".png"
    safe_task_name = RE_SANITIZE_FILENAME.sub("", slugify(task_name))
    timestamp = created_at.strftime("%Y-%m-%d_%H%M%S")
    upload = ImageData(
        filename=f"{timestamp}_{safe_task_name}{extension}",
        file=io.BytesIO(image_bytes),
        content_type=task_result.mime_type,
    )

    target_folder = media_source.MediaSourceItem.from_uri(
        hass, f"media-source://{DOMAIN}/{IMAGE_DIR}", None
    )
    media_source_id = await source.async_upload_media(target_folder, upload)
    service_result["media_source_id"] = media_source_id

    # Resolve the uploaded item to a URL and sign it with a limited lifetime.
    uploaded_item = media_source.MediaSourceItem.from_uri(
        hass, media_source_id, None
    )
    resolved_media = await source.async_resolve_media(uploaded_item)
    service_result["url"] = async_sign_path(
        hass,
        resolved_media.url,
        timedelta(seconds=IMAGE_EXPIRY_TIME),
    )

    return service_result
@dataclass(slots=True)
class GenDataTask:
"""Gen data task to be processed."""
name: str
"""Name of the task."""
instructions: str
"""Instructions on what needs to be done."""
structure: vol.Schema | None = None
"""Optional structure for the data to be generated."""
attachments: list[conversation.Attachment] | None = None
"""List of attachments to go along the instructions."""
llm_api: llm.API | None = None
"""API to provide to the LLM."""
def __str__(self) -> str:
"""Return task as a string."""
return f"<GenDataTask {self.name}: {id(self)}>"
@dataclass(slots=True)
class GenDataTaskResult:
"""Result of gen data task."""
conversation_id: str
"""Unique identifier for the conversation."""
data: Any
"""Data generated by the task."""
def as_dict(self) -> dict[str, Any]:
"""Return result as a dict."""
return {
"conversation_id": self.conversation_id,
"data": self.data,
}
@dataclass(slots=True)
class GenImageTask:
"""Gen image task to be processed."""
name: str
"""Name of the task."""
instructions: str
"""Instructions on what needs to be done."""
attachments: list[conversation.Attachment] | None = None
"""List of attachments to go along the instructions."""
def __str__(self) -> str:
"""Return task as a string."""
return f"<GenImageTask {self.name}: {id(self)}>"
@dataclass(slots=True)
class GenImageTaskResult:
"""Result of gen image task."""
image_data: bytes
"""Raw image data generated by the model."""
conversation_id: str
"""Unique identifier for the conversation."""
mime_type: str
"""MIME type of the generated image."""
width: int | None = None
"""Width of the generated image, if available."""
height: int | None = None
"""Height of the generated image, if available."""
model: str | None = None
"""Model used to generate the image, if available."""
revised_prompt: str | None = None
"""Revised prompt used to generate the image, if applicable."""
def as_dict(self) -> dict[str, Any]:
"""Return result as a dict."""
return {
"image_data": self.image_data,
"conversation_id": self.conversation_id,
"mime_type": self.mime_type,
"width": self.width,
"height": self.height,
"model": self.model,
"revised_prompt": self.revised_prompt,
}
@dataclass(slots=True)
class ImageData:
"""Implementation of media_source.local_source.UploadedFile protocol."""
filename: str
file: io.IOBase
content_type: str