mirror of
				https://github.com/home-assistant/core.git
				synced 2025-11-04 08:29:37 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			350 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			350 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""AI tasks to be handled by agents."""
 | 
						|
 | 
						|
from __future__ import annotations
 | 
						|
 | 
						|
from dataclasses import dataclass
 | 
						|
from datetime import datetime, timedelta
 | 
						|
import io
 | 
						|
import mimetypes
 | 
						|
from pathlib import Path
 | 
						|
import tempfile
 | 
						|
from typing import Any
 | 
						|
 | 
						|
import voluptuous as vol
 | 
						|
 | 
						|
from homeassistant.components import camera, conversation, image, media_source
 | 
						|
from homeassistant.components.http.auth import async_sign_path
 | 
						|
from homeassistant.core import HomeAssistant, ServiceResponse, callback
 | 
						|
from homeassistant.exceptions import HomeAssistantError
 | 
						|
from homeassistant.helpers import llm
 | 
						|
from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session
 | 
						|
from homeassistant.util import RE_SANITIZE_FILENAME, slugify
 | 
						|
 | 
						|
from .const import (
 | 
						|
    DATA_COMPONENT,
 | 
						|
    DATA_MEDIA_SOURCE,
 | 
						|
    DATA_PREFERENCES,
 | 
						|
    DOMAIN,
 | 
						|
    IMAGE_DIR,
 | 
						|
    IMAGE_EXPIRY_TIME,
 | 
						|
    AITaskEntityFeature,
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
def _save_camera_snapshot(image_data: camera.Image | image.Image) -> Path:
    """Save camera snapshot to temp file.

    Writes ``image_data.content`` to a new named temporary file. The suffix is
    guessed from ``image_data.content_type`` using a non-strict ``mimetypes``
    lookup, and there is no suffix when the type is unknown. The file is
    created with ``delete=False``, so it outlives this call; the caller owns
    the returned path and is responsible for deleting it.

    If writing or closing the file fails, the partially written file is
    removed before the exception propagates, because no caller would ever
    learn its path.
    """
    temp_file = tempfile.NamedTemporaryFile(
        mode="wb",
        suffix=mimetypes.guess_extension(image_data.content_type, False),
        delete=False,
    )
    path = Path(temp_file.name)
    try:
        # The ``with`` also covers close(), where buffered data is flushed;
        # errors such as a full disk can surface there rather than in write().
        with temp_file:
            temp_file.write(image_data.content)
    except BaseException:
        # The file is closed at this point, so unlinking also works on Windows.
        path.unlink(missing_ok=True)
        raise
    return path
 | 
						|
 | 
						|
 | 
						|
async def _resolve_attachments(
    hass: HomeAssistant,
    session: ChatSession,
    attachments: list[dict] | None = None,
) -> list[conversation.Attachment]:
    """Resolve attachments for a task.

    Each attachment dict is read for its ``media_content_id``. Some IDs are
    treated as entity IDs:

    - IDs starting with ``media-source://camera/``
    - IDs starting with ``media-source://image/``

    For these, a snapshot is fetched from the entity and written to a
    temporary file. Every other ID is resolved through the media source and
    must map to a local path.

    Temporary files are scheduled for deletion in the executor when
    ``session`` is cleaned up. If resolving fails part-way, the files created
    so far are scheduled for deletion immediately instead, because no session
    cleanup has been registered for them yet.

    Raises:
        HomeAssistantError: if a regular media source item has no local path.

    Errors raised while fetching snapshots or resolving media propagate
    unchanged.
    """
    resolved_attachments: list[conversation.Attachment] = []
    created_files: list[Path] = []

    def cleanup_files() -> None:
        """Cleanup temporary files."""
        for file in created_files:
            file.unlink(missing_ok=True)

    try:
        for attachment in attachments or []:
            media_content_id = attachment["media_content_id"]

            # Special case for certain media sources
            for integration in camera, image:
                media_source_prefix = f"media-source://{integration.DOMAIN}/"
                if not media_content_id.startswith(media_source_prefix):
                    continue

                # Extract entity_id from the media content ID
                entity_id = media_content_id.removeprefix(media_source_prefix)

                # Get snapshot from entity
                image_data = await integration.async_get_image(hass, entity_id)

                temp_filename = await hass.async_add_executor_job(
                    _save_camera_snapshot, image_data
                )
                # Track the file immediately so a later failure cleans it up.
                created_files.append(temp_filename)

                resolved_attachments.append(
                    conversation.Attachment(
                        media_content_id=media_content_id,
                        mime_type=image_data.content_type,
                        path=temp_filename,
                    )
                )
                break
            else:
                # Handle regular media sources
                media = await media_source.async_resolve_media(
                    hass, media_content_id, None
                )
                if media.path is None:
                    raise HomeAssistantError(
                        "Only local attachments are currently supported"
                    )
                resolved_attachments.append(
                    conversation.Attachment(
                        media_content_id=media_content_id,
                        mime_type=media.mime_type,
                        path=media.path,
                    )
                )
    except BaseException:
        # The session cleanup hook below is only registered on success, so
        # without this any snapshots written before the failure would leak.
        # Scheduled (not awaited), matching the session cleanup path.
        if created_files:
            hass.async_add_executor_job(cleanup_files)
        raise

    if not created_files:
        return resolved_attachments

    @callback
    def cleanup_files_callback() -> None:
        """Cleanup temporary files."""
        hass.async_add_executor_job(cleanup_files)

    session.async_on_cleanup(cleanup_files_callback)

    return resolved_attachments
 | 
						|
 | 
						|
 | 
						|
async def async_generate_data(
    hass: HomeAssistant,
    *,
    task_name: str,
    entity_id: str | None = None,
    instructions: str,
    structure: vol.Schema | None = None,
    attachments: list[dict] | None = None,
    llm_api: llm.API | None = None,
) -> GenDataTaskResult:
    """Generate data with an AI Task entity.

    When ``entity_id`` is omitted, the preferred data-generation entity from
    the stored preferences is used. Attachments are resolved inside a new
    chat session, which is then passed to the entity along with the task.

    Raises:
        HomeAssistantError: if no entity can be determined, the entity is not
            found, it cannot generate data, or attachments were given to an
            entity without attachment support.
    """
    target_id = (
        entity_id
        if entity_id is not None
        else hass.data[DATA_PREFERENCES].gen_data_entity_id
    )
    if target_id is None:
        raise HomeAssistantError("No entity_id provided and no preferred entity set")

    entity = hass.data[DATA_COMPONENT].get_entity(target_id)
    if entity is None:
        raise HomeAssistantError(f"AI Task entity {target_id} not found")

    features = entity.supported_features
    if AITaskEntityFeature.GENERATE_DATA not in features:
        raise HomeAssistantError(
            f"AI Task entity {target_id} does not support generating data"
        )
    if attachments and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in features:
        raise HomeAssistantError(
            f"AI Task entity {target_id} does not support attachments"
        )

    with async_get_chat_session(hass) as session:
        resolved = await _resolve_attachments(hass, session, attachments)
        # An empty attachment list is normalized to None for the entity.
        task = GenDataTask(
            name=task_name,
            instructions=instructions,
            structure=structure,
            attachments=resolved or None,
            llm_api=llm_api,
        )
        return await entity.internal_async_generate_data(session, task)
 | 
						|
 | 
						|
 | 
						|
async def async_generate_image(
    hass: HomeAssistant,
    *,
    task_name: str,
    entity_id: str | None = None,
    instructions: str,
    attachments: list[dict] | None = None,
) -> ServiceResponse:
    """Generate an image with an AI Task entity and store it as media.

    When ``entity_id`` is omitted, the preferred image-generation entity from
    the stored preferences is used. The generated bytes are uploaded to this
    integration's media source under ``IMAGE_DIR``. They are not included in
    the response. The returned dict holds the task result fields plus:

    - ``media_source_id``: the uploaded item's media source ID.
    - ``url``: a signed URL valid for ``IMAGE_EXPIRY_TIME`` seconds.

    ``revised_prompt`` falls back to ``instructions`` when the entity did not
    provide one.

    Raises:
        HomeAssistantError: if no entity can be determined, the entity is not
            found, it cannot generate images, or attachments were given to an
            entity without attachment support.
    """
    target_id = (
        entity_id
        if entity_id is not None
        else hass.data[DATA_PREFERENCES].gen_image_entity_id
    )
    if target_id is None:
        raise HomeAssistantError("No entity_id provided and no preferred entity set")

    entity = hass.data[DATA_COMPONENT].get_entity(target_id)
    if entity is None:
        raise HomeAssistantError(f"AI Task entity {target_id} not found")

    features = entity.supported_features
    if AITaskEntityFeature.GENERATE_IMAGE not in features:
        raise HomeAssistantError(
            f"AI Task entity {target_id} does not support generating images"
        )
    if attachments and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in features:
        raise HomeAssistantError(
            f"AI Task entity {target_id} does not support attachments"
        )

    with async_get_chat_session(hass) as session:
        resolved = await _resolve_attachments(hass, session, attachments)
        task = GenImageTask(
            name=task_name,
            instructions=instructions,
            attachments=resolved or None,
        )
        task_result = await entity.internal_async_generate_image(session, task)

    # The raw bytes are uploaded below instead of being returned.
    response = task_result.as_dict()
    raw_image = response.pop("image_data")
    if response.get("revised_prompt") is None:
        response["revised_prompt"] = instructions

    source = hass.data[DATA_MEDIA_SOURCE]

    timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
    extension = mimetypes.guess_extension(task_result.mime_type, False) or ".png"
    safe_name = RE_SANITIZE_FILENAME.sub("", slugify(task_name))
    upload = ImageData(
        filename=f"{timestamp}_{safe_name}{extension}",
        file=io.BytesIO(raw_image),
        content_type=task_result.mime_type,
    )

    folder = media_source.MediaSourceItem.from_uri(
        hass, f"media-source://{DOMAIN}/{IMAGE_DIR}", None
    )
    media_source_id = await source.async_upload_media(folder, upload)
    response["media_source_id"] = media_source_id

    uploaded_item = media_source.MediaSourceItem.from_uri(hass, media_source_id, None)
    playable = await source.async_resolve_media(uploaded_item)
    response["url"] = async_sign_path(
        hass,
        playable.url,
        timedelta(seconds=IMAGE_EXPIRY_TIME),
    )

    return response
 | 
						|
 | 
						|
 | 
						|
@dataclass(slots=True)
 | 
						|
class GenDataTask:
 | 
						|
    """Gen data task to be processed."""
 | 
						|
 | 
						|
    name: str
 | 
						|
    """Name of the task."""
 | 
						|
 | 
						|
    instructions: str
 | 
						|
    """Instructions on what needs to be done."""
 | 
						|
 | 
						|
    structure: vol.Schema | None = None
 | 
						|
    """Optional structure for the data to be generated."""
 | 
						|
 | 
						|
    attachments: list[conversation.Attachment] | None = None
 | 
						|
    """List of attachments to go along the instructions."""
 | 
						|
 | 
						|
    llm_api: llm.API | None = None
 | 
						|
    """API to provide to the LLM."""
 | 
						|
 | 
						|
    def __str__(self) -> str:
 | 
						|
        """Return task as a string."""
 | 
						|
        return f"<GenDataTask {self.name}: {id(self)}>"
 | 
						|
 | 
						|
 | 
						|
@dataclass(slots=True)
 | 
						|
class GenDataTaskResult:
 | 
						|
    """Result of gen data task."""
 | 
						|
 | 
						|
    conversation_id: str
 | 
						|
    """Unique identifier for the conversation."""
 | 
						|
 | 
						|
    data: Any
 | 
						|
    """Data generated by the task."""
 | 
						|
 | 
						|
    def as_dict(self) -> dict[str, Any]:
 | 
						|
        """Return result as a dict."""
 | 
						|
        return {
 | 
						|
            "conversation_id": self.conversation_id,
 | 
						|
            "data": self.data,
 | 
						|
        }
 | 
						|
 | 
						|
 | 
						|
@dataclass(slots=True)
 | 
						|
class GenImageTask:
 | 
						|
    """Gen image task to be processed."""
 | 
						|
 | 
						|
    name: str
 | 
						|
    """Name of the task."""
 | 
						|
 | 
						|
    instructions: str
 | 
						|
    """Instructions on what needs to be done."""
 | 
						|
 | 
						|
    attachments: list[conversation.Attachment] | None = None
 | 
						|
    """List of attachments to go along the instructions."""
 | 
						|
 | 
						|
    def __str__(self) -> str:
 | 
						|
        """Return task as a string."""
 | 
						|
        return f"<GenImageTask {self.name}: {id(self)}>"
 | 
						|
 | 
						|
 | 
						|
@dataclass(slots=True)
 | 
						|
class GenImageTaskResult:
 | 
						|
    """Result of gen image task."""
 | 
						|
 | 
						|
    image_data: bytes
 | 
						|
    """Raw image data generated by the model."""
 | 
						|
 | 
						|
    conversation_id: str
 | 
						|
    """Unique identifier for the conversation."""
 | 
						|
 | 
						|
    mime_type: str
 | 
						|
    """MIME type of the generated image."""
 | 
						|
 | 
						|
    width: int | None = None
 | 
						|
    """Width of the generated image, if available."""
 | 
						|
 | 
						|
    height: int | None = None
 | 
						|
    """Height of the generated image, if available."""
 | 
						|
 | 
						|
    model: str | None = None
 | 
						|
    """Model used to generate the image, if available."""
 | 
						|
 | 
						|
    revised_prompt: str | None = None
 | 
						|
    """Revised prompt used to generate the image, if applicable."""
 | 
						|
 | 
						|
    def as_dict(self) -> dict[str, Any]:
 | 
						|
        """Return result as a dict."""
 | 
						|
        return {
 | 
						|
            "image_data": self.image_data,
 | 
						|
            "conversation_id": self.conversation_id,
 | 
						|
            "mime_type": self.mime_type,
 | 
						|
            "width": self.width,
 | 
						|
            "height": self.height,
 | 
						|
            "model": self.model,
 | 
						|
            "revised_prompt": self.revised_prompt,
 | 
						|
        }
 | 
						|
 | 
						|
 | 
						|
@dataclass(slots=True)
 | 
						|
class ImageData:
 | 
						|
    """Implementation of media_source.local_source.UploadedFile protocol."""
 | 
						|
 | 
						|
    filename: str
 | 
						|
    file: io.IOBase
 | 
						|
    content_type: str
 |