mirror of
https://github.com/home-assistant/core.git
synced 2025-10-04 09:19:28 +00:00
350 lines
10 KiB
Python
350 lines
10 KiB
Python
"""AI tasks to be handled by agents."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta
|
|
import io
|
|
import mimetypes
|
|
from pathlib import Path
|
|
import tempfile
|
|
from typing import Any
|
|
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.components import camera, conversation, image, media_source
|
|
from homeassistant.components.http.auth import async_sign_path
|
|
from homeassistant.core import HomeAssistant, ServiceResponse, callback
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
from homeassistant.helpers import llm
|
|
from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session
|
|
from homeassistant.util import RE_SANITIZE_FILENAME, slugify
|
|
|
|
from .const import (
|
|
DATA_COMPONENT,
|
|
DATA_MEDIA_SOURCE,
|
|
DATA_PREFERENCES,
|
|
DOMAIN,
|
|
IMAGE_DIR,
|
|
IMAGE_EXPIRY_TIME,
|
|
AITaskEntityFeature,
|
|
)
|
|
|
|
|
|
def _save_camera_snapshot(image_data: camera.Image | image.Image) -> Path:
    """Save camera snapshot to temp file.

    The file name suffix is guessed from ``image_data.content_type``
    (non-strict lookup). The file is created with ``delete=False``, so the
    caller owns the returned path and is responsible for removing it.

    If writing or closing the file fails, the partially written file is
    removed before the exception is re-raised, so a failed call leaves
    nothing behind on disk.
    """
    suffix = mimetypes.guess_extension(image_data.content_type, False)
    temp_file = tempfile.NamedTemporaryFile(mode="wb", suffix=suffix, delete=False)
    snapshot_path = Path(temp_file.name)

    try:
        # The context manager closes (and flushes) the handle before the
        # except clause runs, so the unlink below never hits an open file.
        with temp_file:
            temp_file.write(image_data.content)
    except BaseException:
        # delete=False means nothing else will ever remove this file.
        snapshot_path.unlink(missing_ok=True)
        raise

    return snapshot_path
|
|
|
|
|
|
async def _resolve_attachments(
    hass: HomeAssistant,
    session: ChatSession,
    attachments: list[dict] | None = None,
) -> list[conversation.Attachment]:
    """Resolve attachments for a task.

    Each entry of ``attachments`` must provide a ``media_content_id`` key.

    * IDs starting with ``media-source://<camera or image domain>/`` are
      treated as entity references: a snapshot is fetched from the entity
      and written to a temporary file.
    * Every other ID is resolved through the media source integration and
      must resolve to a local path.

    Temporary snapshot files are removed by a cleanup callback registered
    on ``session``. If resolving fails part-way through, the snapshot files
    created so far are scheduled for deletion right away, because no
    cleanup callback has been registered at that point.

    Raises:
        HomeAssistantError: A regular media source did not resolve to a
            local path.
    """
    resolved_attachments: list[conversation.Attachment] = []
    created_files: list[Path] = []

    def cleanup_files() -> None:
        """Cleanup temporary files."""
        for file in created_files:
            file.unlink(missing_ok=True)

    try:
        for attachment in attachments or []:
            media_content_id = attachment["media_content_id"]

            # Special case for certain media sources
            for integration in camera, image:
                media_source_prefix = f"media-source://{integration.DOMAIN}/"
                if not media_content_id.startswith(media_source_prefix):
                    continue

                # Extract entity_id from the media content ID
                entity_id = media_content_id.removeprefix(media_source_prefix)

                # Get snapshot from entity
                image_data = await integration.async_get_image(hass, entity_id)

                temp_filename = await hass.async_add_executor_job(
                    _save_camera_snapshot, image_data
                )
                created_files.append(temp_filename)

                resolved_attachments.append(
                    conversation.Attachment(
                        media_content_id=media_content_id,
                        mime_type=image_data.content_type,
                        path=temp_filename,
                    )
                )
                break
            else:
                # Handle regular media sources
                media = await media_source.async_resolve_media(
                    hass, media_content_id, None
                )
                if media.path is None:
                    raise HomeAssistantError(
                        "Only local attachments are currently supported"
                    )
                resolved_attachments.append(
                    conversation.Attachment(
                        media_content_id=media_content_id,
                        mime_type=media.mime_type,
                        path=media.path,
                    )
                )
    except BaseException:
        # The session cleanup callback is only registered after the loop
        # succeeds, so snapshots written before the failure would otherwise
        # be left on disk. Unlinking is blocking I/O, hence the executor.
        if created_files:
            hass.async_add_executor_job(cleanup_files)
        raise

    if not created_files:
        return resolved_attachments

    @callback
    def cleanup_files_callback() -> None:
        """Cleanup temporary files."""
        hass.async_add_executor_job(cleanup_files)

    session.async_on_cleanup(cleanup_files_callback)

    return resolved_attachments
|
|
|
|
|
|
async def async_generate_data(
    hass: HomeAssistant,
    *,
    task_name: str,
    entity_id: str | None = None,
    instructions: str,
    structure: vol.Schema | None = None,
    attachments: list[dict] | None = None,
    llm_api: llm.API | None = None,
) -> GenDataTaskResult:
    """Run a data generation task in the AI Task integration.

    When ``entity_id`` is omitted, the preferred data generation entity
    stored in the integration preferences is used.

    Raises:
        HomeAssistantError: No entity could be determined, the entity does
            not exist, it cannot generate data, or attachments were passed
            to an entity that does not support them.
    """
    if entity_id is None:
        # Fall back to the entity configured in the preferences.
        entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id
        if entity_id is None:
            raise HomeAssistantError(
                "No entity_id provided and no preferred entity set"
            )

    entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
    if entity is None:
        raise HomeAssistantError(f"AI Task entity {entity_id} not found")

    features = entity.supported_features

    if AITaskEntityFeature.GENERATE_DATA not in features:
        raise HomeAssistantError(
            f"AI Task entity {entity_id} does not support generating data"
        )

    if attachments and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in features:
        raise HomeAssistantError(
            f"AI Task entity {entity_id} does not support attachments"
        )

    with async_get_chat_session(hass) as session:
        resolved = await _resolve_attachments(hass, session, attachments)

        task = GenDataTask(
            name=task_name,
            instructions=instructions,
            structure=structure,
            # An empty list is handed to the entity as None.
            attachments=resolved or None,
            llm_api=llm_api,
        )
        return await entity.internal_async_generate_data(session, task)
|
|
|
|
|
|
async def async_generate_image(
    hass: HomeAssistant,
    *,
    task_name: str,
    entity_id: str | None = None,
    instructions: str,
    attachments: list[dict] | None = None,
) -> ServiceResponse:
    """Run an image generation task in the AI Task integration.

    When ``entity_id`` is omitted, the preferred image generation entity
    stored in the integration preferences is used. The generated image is
    uploaded to the integration's media source; the returned dict is the
    task result without the raw image bytes, extended with
    ``media_source_id`` and a signed ``url``.

    Raises:
        HomeAssistantError: No entity could be determined, the entity does
            not exist, it cannot generate images, or attachments were
            passed to an entity that does not support them.
    """
    if entity_id is None:
        # Fall back to the entity configured in the preferences.
        entity_id = hass.data[DATA_PREFERENCES].gen_image_entity_id
        if entity_id is None:
            raise HomeAssistantError(
                "No entity_id provided and no preferred entity set"
            )

    entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
    if entity is None:
        raise HomeAssistantError(f"AI Task entity {entity_id} not found")

    features = entity.supported_features

    if AITaskEntityFeature.GENERATE_IMAGE not in features:
        raise HomeAssistantError(
            f"AI Task entity {entity_id} does not support generating images"
        )

    if attachments and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in features:
        raise HomeAssistantError(
            f"AI Task entity {entity_id} does not support attachments"
        )

    with async_get_chat_session(hass) as session:
        resolved = await _resolve_attachments(hass, session, attachments)

        task = GenImageTask(
            name=task_name,
            instructions=instructions,
            # An empty list is handed to the entity as None.
            attachments=resolved or None,
        )
        task_result = await entity.internal_async_generate_image(session, task)

    service_result = task_result.as_dict()
    # The raw bytes are stored as a media file instead of being returned.
    raw_image = service_result.pop("image_data")
    if service_result.get("revised_prompt") is None:
        service_result["revised_prompt"] = instructions

    media_store = hass.data[DATA_MEDIA_SOURCE]

    timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
    extension = mimetypes.guess_extension(task_result.mime_type, False) or ".png"
    safe_task_name = RE_SANITIZE_FILENAME.sub("", slugify(task_name))

    upload = ImageData(
        filename=f"{timestamp}_{safe_task_name}{extension}",
        file=io.BytesIO(raw_image),
        content_type=task_result.mime_type,
    )

    upload_folder = media_source.MediaSourceItem.from_uri(
        hass, f"media-source://{DOMAIN}/{IMAGE_DIR}", None
    )
    media_source_id = await media_store.async_upload_media(upload_folder, upload)
    service_result["media_source_id"] = media_source_id

    uploaded_item = media_source.MediaSourceItem.from_uri(
        hass, service_result["media_source_id"], None
    )
    resolved_media = await media_store.async_resolve_media(uploaded_item)
    # Sign the media URL with an expiration of IMAGE_EXPIRY_TIME seconds.
    service_result["url"] = async_sign_path(
        hass,
        resolved_media.url,
        timedelta(seconds=IMAGE_EXPIRY_TIME),
    )

    return service_result
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class GenDataTask:
|
|
"""Gen data task to be processed."""
|
|
|
|
name: str
|
|
"""Name of the task."""
|
|
|
|
instructions: str
|
|
"""Instructions on what needs to be done."""
|
|
|
|
structure: vol.Schema | None = None
|
|
"""Optional structure for the data to be generated."""
|
|
|
|
attachments: list[conversation.Attachment] | None = None
|
|
"""List of attachments to go along the instructions."""
|
|
|
|
llm_api: llm.API | None = None
|
|
"""API to provide to the LLM."""
|
|
|
|
def __str__(self) -> str:
|
|
"""Return task as a string."""
|
|
return f"<GenDataTask {self.name}: {id(self)}>"
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class GenDataTaskResult:
|
|
"""Result of gen data task."""
|
|
|
|
conversation_id: str
|
|
"""Unique identifier for the conversation."""
|
|
|
|
data: Any
|
|
"""Data generated by the task."""
|
|
|
|
def as_dict(self) -> dict[str, Any]:
|
|
"""Return result as a dict."""
|
|
return {
|
|
"conversation_id": self.conversation_id,
|
|
"data": self.data,
|
|
}
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class GenImageTask:
|
|
"""Gen image task to be processed."""
|
|
|
|
name: str
|
|
"""Name of the task."""
|
|
|
|
instructions: str
|
|
"""Instructions on what needs to be done."""
|
|
|
|
attachments: list[conversation.Attachment] | None = None
|
|
"""List of attachments to go along the instructions."""
|
|
|
|
def __str__(self) -> str:
|
|
"""Return task as a string."""
|
|
return f"<GenImageTask {self.name}: {id(self)}>"
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class GenImageTaskResult:
|
|
"""Result of gen image task."""
|
|
|
|
image_data: bytes
|
|
"""Raw image data generated by the model."""
|
|
|
|
conversation_id: str
|
|
"""Unique identifier for the conversation."""
|
|
|
|
mime_type: str
|
|
"""MIME type of the generated image."""
|
|
|
|
width: int | None = None
|
|
"""Width of the generated image, if available."""
|
|
|
|
height: int | None = None
|
|
"""Height of the generated image, if available."""
|
|
|
|
model: str | None = None
|
|
"""Model used to generate the image, if available."""
|
|
|
|
revised_prompt: str | None = None
|
|
"""Revised prompt used to generate the image, if applicable."""
|
|
|
|
def as_dict(self) -> dict[str, Any]:
|
|
"""Return result as a dict."""
|
|
return {
|
|
"image_data": self.image_data,
|
|
"conversation_id": self.conversation_id,
|
|
"mime_type": self.mime_type,
|
|
"width": self.width,
|
|
"height": self.height,
|
|
"model": self.model,
|
|
"revised_prompt": self.revised_prompt,
|
|
}
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class ImageData:
|
|
"""Implementation of media_source.local_source.UploadedFile protocol."""
|
|
|
|
filename: str
|
|
file: io.IOBase
|
|
content_type: str
|