Save AI generated images to files (#152231)

This commit is contained in:
Denis Shulyaka
2025-09-14 00:37:39 +03:00
committed by GitHub
parent ab1619c0b4
commit d93e0a105a
12 changed files with 131 additions and 320 deletions

View File

@@ -3,10 +3,8 @@
import logging
from typing import Any
from aiohttp import web
import voluptuous as vol
from homeassistant.components.http import KEY_HASS, HomeAssistantView
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
from homeassistant.core import (
@@ -28,7 +26,6 @@ from .const import (
ATTR_STRUCTURE,
ATTR_TASK_NAME,
DATA_COMPONENT,
DATA_IMAGES,
DATA_PREFERENCES,
DOMAIN,
SERVICE_GENERATE_DATA,
@@ -42,7 +39,6 @@ from .task import (
GenDataTaskResult,
GenImageTask,
GenImageTaskResult,
ImageData,
async_generate_data,
async_generate_image,
)
@@ -55,7 +51,6 @@ __all__ = [
"GenDataTaskResult",
"GenImageTask",
"GenImageTaskResult",
"ImageData",
"async_generate_data",
"async_generate_image",
"async_setup",
@@ -94,10 +89,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass)
hass.data[DATA_COMPONENT] = entity_component
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
hass.data[DATA_IMAGES] = {}
await hass.data[DATA_PREFERENCES].async_load()
async_setup_http(hass)
hass.http.register_view(ImageView)
hass.services.async_register(
DOMAIN,
SERVICE_GENERATE_DATA,
@@ -209,28 +202,3 @@ class AITaskPreferences:
def as_dict(self) -> dict[str, str | None]:
"""Get the current preferences."""
return {key: getattr(self, key) for key in self.KEYS}
class ImageView(HomeAssistantView):
    """HTTP view that serves generated images from the in-memory image store."""

    url = f"/api/{DOMAIN}/images/{{filename}}"
    name = f"api:{DOMAIN}/images"

    async def get(
        self,
        request: web.Request,
        filename: str,
    ) -> web.Response:
        """Return the stored image registered under ``filename``.

        Raises ``web.HTTPNotFound`` when nothing is stored under that name.
        """
        stored = request.app[KEY_HASS].data[DATA_IMAGES].get(filename)
        if stored is not None:
            return web.Response(
                body=stored.data,
                content_type=stored.mime_type,
            )
        raise web.HTTPNotFound

View File

@@ -8,19 +8,19 @@ from typing import TYPE_CHECKING, Final
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from homeassistant.components.media_source import local_source
from homeassistant.helpers.entity_component import EntityComponent
from . import AITaskPreferences
from .entity import AITaskEntity
from .task import ImageData
# Integration domain; also the prefix of the hass.data keys defined below.
DOMAIN = "ai_task"
# Typed hass.data keys: the entity component and the stored preferences.
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
# In-memory store of generated images, keyed by filename.
DATA_IMAGES: HassKey[dict[str, ImageData]] = HassKey(f"{DOMAIN}_images")
# Local media source instance that generated images are uploaded to.
DATA_MEDIA_SOURCE: HassKey[local_source.LocalSource] = HassKey(f"{DOMAIN}_media_source")
# Sub-directory (below the DOMAIN directory in the config path) holding image files.
IMAGE_DIR: Final = "image"
# Seconds; used as the lifetime of the signed image URL.
IMAGE_EXPIRY_TIME = 60 * 60 # 1 hour
# Upper bound on the number of entries kept in the in-memory image store.
MAX_IMAGES = 20
# Service names registered under DOMAIN.
SERVICE_GENERATE_DATA = "generate_data"
SERVICE_GENERATE_IMAGE = "generate_image"

View File

@@ -1,7 +1,7 @@
{
"domain": "ai_task",
"name": "AI Task",
"after_dependencies": ["camera", "http"],
"after_dependencies": ["camera"],
"codeowners": ["@home-assistant/core"],
"dependencies": ["conversation", "media_source"],
"documentation": "https://www.home-assistant.io/integrations/ai_task",

View File

@@ -2,89 +2,21 @@
from __future__ import annotations
from datetime import timedelta
import logging
from homeassistant.components.http.auth import async_sign_path
from homeassistant.components.media_player import BrowseError, MediaClass
from homeassistant.components.media_source import (
BrowseMediaSource,
MediaSource,
MediaSourceItem,
PlayMedia,
Unresolvable,
)
from homeassistant.components.media_source import MediaSource, local_source
from homeassistant.core import HomeAssistant
from .const import DATA_IMAGES, DOMAIN, IMAGE_EXPIRY_TIME
_LOGGER = logging.getLogger(__name__)
from .const import DATA_MEDIA_SOURCE, DOMAIN, IMAGE_DIR
async def async_get_media_source(hass: HomeAssistant) -> ImageMediaSource:
    """Create the media source that exposes generated images."""
    _LOGGER.debug("Setting up image media source")
    image_source = ImageMediaSource(hass)
    return image_source
async def async_get_media_source(hass: HomeAssistant) -> MediaSource:
"""Set up local media source."""
media_dir = hass.config.path(f"{DOMAIN}/{IMAGE_DIR}")
class ImageMediaSource(MediaSource):
    """Media source that exposes the in-memory generated images."""

    name: str = "AI Generated Images"

    def __init__(self, hass: HomeAssistant) -> None:
        """Register the source under the integration domain and keep ``hass``."""
        super().__init__(DOMAIN)
        self.hass = hass

    async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
        """Resolve a stored image to a signed URL and its MIME type.

        Raises ``Unresolvable`` when the identifier is not in the image store.
        """
        stored = self.hass.data[DATA_IMAGES].get(item.identifier)
        if stored is None:
            raise Unresolvable(f"Could not resolve media item: {item.identifier}")

        signed_url = async_sign_path(
            self.hass,
            f"/api/{DOMAIN}/images/{item.identifier}",
            timedelta(seconds=IMAGE_EXPIRY_TIME or 1800),
        )
        return PlayMedia(signed_url, stored.mime_type)

    async def async_browse_media(
        self,
        item: MediaSourceItem,
    ) -> BrowseMediaSource:
        """List every stored image as a child of a single root node.

        Only the root can be browsed; a non-empty identifier raises
        ``BrowseError``.
        """
        if item.identifier:
            raise BrowseError("Unknown item")

        entries: list[BrowseMediaSource] = []
        for filename, image in self.hass.data[DATA_IMAGES].items():
            entries.append(
                BrowseMediaSource(
                    domain=DOMAIN,
                    identifier=filename,
                    media_class=MediaClass.IMAGE,
                    media_content_type=image.mime_type,
                    # Fall back to the filename when the image has no title.
                    title=image.title or filename,
                    can_play=True,
                    can_expand=False,
                )
            )

        return BrowseMediaSource(
            domain=DOMAIN,
            identifier=None,
            media_class=MediaClass.APP,
            media_content_type="",
            title="AI Generated Images",
            can_play=False,
            can_expand=True,
            children_media_class=MediaClass.IMAGE,
            children=entries,
        )
hass.data[DATA_MEDIA_SOURCE] = source = local_source.LocalSource(
hass,
DOMAIN,
"AI Generated Images",
{IMAGE_DIR: media_dir},
f"/{DOMAIN}",
)
return source

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import partial
import io
import mimetypes
from pathlib import Path
import tempfile
@@ -18,16 +18,15 @@ from homeassistant.core import HomeAssistant, ServiceResponse, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session
from homeassistant.helpers.event import async_call_later
from homeassistant.util import RE_SANITIZE_FILENAME, slugify
from .const import (
DATA_COMPONENT,
DATA_IMAGES,
DATA_MEDIA_SOURCE,
DATA_PREFERENCES,
DOMAIN,
IMAGE_DIR,
IMAGE_EXPIRY_TIME,
MAX_IMAGES,
AITaskEntityFeature,
)
@@ -157,24 +156,6 @@ async def async_generate_data(
)
def _cleanup_images(image_storage: dict[str, ImageData], num_to_remove: int) -> None:
    """Evict the oldest ``num_to_remove`` entries from ``image_storage`` in place.

    Age is judged by each entry's ``timestamp``. A non-positive count is a
    no-op; a count that covers the whole store empties it.
    """
    if num_to_remove <= 0:
        return

    if num_to_remove < len(image_storage):
        # Sorting the keys is stable, so entries with equal timestamps are
        # evicted in insertion order.
        oldest_first = sorted(
            image_storage,
            key=lambda name: image_storage[name].timestamp,
        )
        for stale_name in oldest_first[:num_to_remove]:
            del image_storage[stale_name]
    else:
        image_storage.clear()
async def async_generate_image(
hass: HomeAssistant,
*,
@@ -224,36 +205,34 @@ async def async_generate_image(
if service_result.get("revised_prompt") is None:
service_result["revised_prompt"] = instructions
image_storage = hass.data[DATA_IMAGES]
if len(image_storage) + 1 > MAX_IMAGES:
_cleanup_images(image_storage, len(image_storage) + 1 - MAX_IMAGES)
source = hass.data[DATA_MEDIA_SOURCE]
current_time = datetime.now()
ext = mimetypes.guess_extension(task_result.mime_type, False) or ".png"
sanitized_task_name = RE_SANITIZE_FILENAME.sub("", slugify(task_name))
filename = f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}"
image_storage[filename] = ImageData(
data=image_data,
timestamp=int(current_time.timestamp()),
mime_type=task_result.mime_type,
title=service_result["revised_prompt"],
image_file = ImageData(
filename=f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}",
file=io.BytesIO(image_data),
content_type=task_result.mime_type,
)
def _purge_image(filename: str, now: datetime) -> None:
"""Remove image from storage."""
image_storage.pop(filename, None)
target_folder = media_source.MediaSourceItem.from_uri(
hass, f"media-source://{DOMAIN}/{IMAGE_DIR}", None
)
if IMAGE_EXPIRY_TIME > 0:
async_call_later(hass, IMAGE_EXPIRY_TIME, partial(_purge_image, filename))
service_result["media_source_id"] = await source.async_upload_media(
target_folder, image_file
)
item = media_source.MediaSourceItem.from_uri(
hass, service_result["media_source_id"], None
)
service_result["url"] = async_sign_path(
hass,
f"/api/{DOMAIN}/images/{filename}",
timedelta(seconds=IMAGE_EXPIRY_TIME or 1800),
(await source.async_resolve_media(item)).url,
timedelta(seconds=IMAGE_EXPIRY_TIME),
)
service_result["media_source_id"] = f"media-source://{DOMAIN}/images/{filename}"
return service_result
@@ -358,20 +337,8 @@ class GenImageTaskResult:
@dataclass(slots=True)
class ImageData:
"""Image data for stored generated images."""
"""Implementation of media_source.local_source.UploadedFile protocol."""
data: bytes
"""Raw image data."""
timestamp: int
"""Timestamp when the image was generated, as a Unix timestamp."""
mime_type: str
"""MIME type of the image."""
title: str
"""Title of the image, usually the prompt used to generate it."""
def __str__(self) -> str:
"""Return image data as a string."""
return f"<ImageData {self.title}: {id(self)}>"
filename: str
file: io.IOBase
content_type: str

View File

@@ -26,6 +26,7 @@ EXCLUDE_FROM_BACKUP = [
"tmp_backups/*.tar",
"OZW_Log.txt",
"tts/*",
"ai_task/*",
]
EXCLUDE_DATABASE_FROM_BACKUP = [

View File

@@ -157,4 +157,4 @@ async def init_components(
with mock_config_flow(TEST_DOMAIN, ConfigFlow):
assert await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
await hass.async_block_till_done(wait_background_tasks=True)

View File

@@ -4,13 +4,14 @@ from pathlib import Path
from typing import Any
from unittest.mock import patch
from freezegun import freeze_time
from freezegun.api import FrozenDateTimeFactory
import pytest
import voluptuous as vol
from homeassistant.components import media_source
from homeassistant.components.ai_task import AITaskPreferences
from homeassistant.components.ai_task.const import DATA_PREFERENCES
from homeassistant.components.ai_task.const import DATA_MEDIA_SOURCE, DATA_PREFERENCES
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import selector
@@ -291,6 +292,7 @@ async def test_generate_data_service_invalid_structure(
),
],
)
@freeze_time("2025-06-14 22:59:00")
async def test_generate_image_service(
hass: HomeAssistant,
init_components: None,
@@ -302,21 +304,32 @@ async def test_generate_image_service(
preferences = hass.data[DATA_PREFERENCES]
preferences.async_set_preferences(**set_preferences)
result = await hass.services.async_call(
"ai_task",
"generate_image",
{
"task_name": "Test Image",
"instructions": "Generate a test image",
}
| msg_extra,
blocking=True,
return_response=True,
)
with patch.object(
hass.data[DATA_MEDIA_SOURCE],
"async_upload_media",
return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png",
) as mock_upload_media:
result = await hass.services.async_call(
"ai_task",
"generate_image",
{
"task_name": "Test Image",
"instructions": "Generate a test image",
}
| msg_extra,
blocking=True,
return_response=True,
)
mock_upload_media.assert_called_once()
assert "image_data" not in result
assert result["media_source_id"].startswith("media-source://ai_task/images/")
assert result["url"].startswith("/api/ai_task/images/")
assert (
result["media_source_id"]
== "media-source://ai_task/image/2025-06-14_225900_test_task.png"
)
assert result["url"].startswith(
"/ai_task/image/2025-06-14_225900_test_task.png?authSig="
)
assert result["mime_type"] == "image/png"
assert result["model"] == "mock_model"
assert result["revised_prompt"] == "mock_revised_prompt"

View File

@@ -1,64 +1,11 @@
"""Test ai_task media source."""
import pytest
from homeassistant.components import media_source
from homeassistant.components.ai_task import ImageData
from homeassistant.core import HomeAssistant
@pytest.fixture(name="image_id")
async def mock_image_generate(hass: HomeAssistant) -> str:
    """Put one mock image into the ai_task image store and return its filename."""
    image_id = "2025-06-15_150640_test_task.png"
    mock_image = ImageData(
        data=b"A",
        timestamp=1750000000,
        mime_type="image/png",
        title="Mock Image",
    )
    hass.data.setdefault("ai_task_images", {})[image_id] = mock_image
    return image_id
async def test_local_media_source(hass: HomeAssistant, init_components: None) -> None:
"""Test that the image media source is created."""
item = await media_source.async_browse_media(hass, "media-source://")
async def test_browsing(
    hass: HomeAssistant, init_components: None, image_id: str
) -> None:
    """Check the ai_task root listing and that sub-paths are rejected."""
    root = await media_source.async_browse_media(hass, "media-source://ai_task")

    assert root is not None
    assert root.title == "AI Generated Images"
    assert len(root.children) == 1

    # The single child must be the image stored by the ``image_id`` fixture.
    child = root.children[0]
    assert child.media_content_type == "image/png"
    assert child.identifier == image_id
    assert child.title == "Mock Image"

    with pytest.raises(media_source.BrowseError, match="Unknown item"):
        await media_source.async_browse_media(
            hass, "media-source://ai_task/invalid_path"
        )
async def test_resolving(
    hass: HomeAssistant, init_components: None, image_id: str
) -> None:
    """Check that a stored image resolves to a signed URL and unknown ids fail."""
    resolved = await media_source.async_resolve_media(
        hass, f"media-source://ai_task/{image_id}", None
    )

    assert resolved is not None
    expected_prefix = f"/api/ai_task/images/{image_id}?authSig="
    assert resolved.url.startswith(expected_prefix)
    assert resolved.mime_type == "image/png"

    # An identifier that was never stored must be reported as unresolvable.
    invalid_id = "aabbccddeeff"
    expected_error = f"Could not resolve media item: {invalid_id}"
    with pytest.raises(media_source.Unresolvable, match=expected_error):
        await media_source.async_resolve_media(
            hass, f"media-source://ai_task/{invalid_id}", None
        )
assert any(c.title == "AI Generated Images" for c in item.children)

View File

@@ -1,6 +1,6 @@
"""Test tasks for the AI Task integration."""
from datetime import datetime, timedelta
from datetime import timedelta
from pathlib import Path
from unittest.mock import patch
@@ -11,10 +11,10 @@ from syrupy.assertion import SnapshotAssertion
from homeassistant.components import media_source
from homeassistant.components.ai_task import (
AITaskEntityFeature,
ImageData,
async_generate_data,
async_generate_image,
)
from homeassistant.components.ai_task.const import DATA_MEDIA_SOURCE
from homeassistant.components.camera import Image
from homeassistant.components.conversation import async_get_chat_log
from homeassistant.const import STATE_UNKNOWN
@@ -257,6 +257,7 @@ async def test_generate_data_mixed_attachments(
assert media_attachment.path == Path("/media/test.mp4")
@freeze_time("2025-06-14 22:59:00")
async def test_generate_image(
hass: HomeAssistant,
init_components: None,
@@ -277,17 +278,26 @@ async def test_generate_image(
assert state is not None
assert state.state == STATE_UNKNOWN
result = await async_generate_image(
hass,
task_name="Test Task",
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
)
with patch.object(
hass.data[DATA_MEDIA_SOURCE],
"async_upload_media",
return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png",
) as mock_upload_media:
result = await async_generate_image(
hass,
task_name="Test Task",
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
)
mock_upload_media.assert_called_once()
assert "image_data" not in result
assert result["media_source_id"].startswith("media-source://ai_task/images/")
assert result["media_source_id"].endswith("_test_task.png")
assert result["url"].startswith("/api/ai_task/images/")
assert result["url"].count("_test_task.png?authSig=") == 1
assert (
result["media_source_id"]
== "media-source://ai_task/image/2025-06-14_225900_test_task.png"
)
assert result["url"].startswith(
"/ai_task/image/2025-06-14_225900_test_task.png?authSig="
)
assert result["mime_type"] == "image/png"
assert result["model"] == "mock_model"
assert result["revised_prompt"] == "mock_revised_prompt"
@@ -309,40 +319,3 @@ async def test_generate_image(
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
)
async def test_image_cleanup(
    hass: HomeAssistant,
    init_components: None,
    mock_ai_task_entity: MockAITaskEntity,
) -> None:
    """Check that the image store stays capped and that entries expire."""
    image_storage = hass.data.setdefault("ai_task_images", {})
    image_storage.clear()

    # Pre-fill the store with 20 mock images.
    for idx in range(20):
        image_storage[str(idx)] = ImageData(
            data=b"mock_image_data",
            timestamp=int(datetime.now().timestamp()),
            mime_type="image/png",
            title="Test Image",
        )
    assert len(image_storage) == 20

    result = await async_generate_image(
        hass,
        task_name="Test Task",
        entity_id=TEST_ENTITY_ID,
        instructions="Test prompt",
    )

    # The filename is the last path segment of the URL, before the signature.
    unsigned_path = result["url"].split("?authSig=")[0]
    generated_name = unsigned_path.split("/")[-1]
    assert generated_name in image_storage
    assert len(image_storage) == 20

    # Move the clock past one hour and one second; one entry must be gone.
    later = dt_util.utcnow() + timedelta(hours=1, seconds=1)
    async_fire_time_changed(hass, later)
    await hass.async_block_till_done()

    assert len(image_storage) == 19

View File

@@ -3,6 +3,7 @@
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch
from freezegun import freeze_time
from google.genai.types import File, FileState, GenerateContentResponse
import pytest
import voluptuous as vol
@@ -222,6 +223,7 @@ async def test_generate_data(
@pytest.mark.usefixtures("mock_init_component")
@freeze_time("2025-06-14 22:59:00")
async def test_generate_image(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
@@ -255,14 +257,17 @@ async def test_generate_image(
],
)
assert hass.data[ai_task.DATA_IMAGES] == {}
result = await ai_task.async_generate_image(
hass,
task_name="Test Task",
entity_id="ai_task.google_ai_task",
instructions="Generate a test image",
)
with patch.object(
media_source.local_source.LocalSource,
"async_upload_media",
return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png",
) as mock_upload_media:
result = await ai_task.async_generate_image(
hass,
task_name="Test Task",
entity_id="ai_task.google_ai_task",
instructions="Generate a test image",
)
assert result["height"] is None
assert result["width"] is None
@@ -270,11 +275,11 @@ async def test_generate_image(
assert result["mime_type"] == "image/png"
assert result["model"] == RECOMMENDED_IMAGE_MODEL.partition("/")[-1]
assert len(hass.data[ai_task.DATA_IMAGES]) == 1
image_data = next(iter(hass.data[ai_task.DATA_IMAGES].values()))
assert image_data.data == mock_image_data
assert image_data.mime_type == "image/png"
assert image_data.title == "Generate a test image"
mock_upload_media.assert_called_once()
image_data = mock_upload_media.call_args[0][1]
assert image_data.file.getvalue() == mock_image_data
assert image_data.content_type == "image/png"
assert image_data.filename == "2025-06-14_225900_test_task.png"
# Verify that generate_content was called with correct parameters
assert mock_generate_content.called

View File

@@ -3,6 +3,7 @@
from pathlib import Path
from unittest.mock import AsyncMock, patch
from freezegun import freeze_time
import httpx
from openai import PermissionDeniedError
import pytest
@@ -212,6 +213,7 @@ async def test_generate_data_with_attachments(
@pytest.mark.usefixtures("mock_init_component")
@freeze_time("2025-06-14 22:59:00")
async def test_generate_image(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
@@ -241,14 +243,17 @@ async def test_generate_image(
create_message_item(id="msg_A", text="", output_index=1),
]
assert hass.data[ai_task.DATA_IMAGES] == {}
result = await ai_task.async_generate_image(
hass,
task_name="Test Task",
entity_id="ai_task.openai_ai_task",
instructions="Generate test image",
)
with patch.object(
media_source.local_source.LocalSource,
"async_upload_media",
return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png",
) as mock_upload_media:
result = await ai_task.async_generate_image(
hass,
task_name="Test Task",
entity_id="ai_task.openai_ai_task",
instructions="Generate test image",
)
assert result["height"] == 1024
assert result["width"] == 1536
@@ -256,11 +261,11 @@ async def test_generate_image(
assert result["mime_type"] == "image/png"
assert result["model"] == "gpt-image-1"
assert len(hass.data[ai_task.DATA_IMAGES]) == 1
image_data = next(iter(hass.data[ai_task.DATA_IMAGES].values()))
assert image_data.data == b"A"
assert image_data.mime_type == "image/png"
assert image_data.title == "Mock revised prompt."
mock_upload_media.assert_called_once()
image_data = mock_upload_media.call_args[0][1]
assert image_data.file.getvalue() == b"A"
assert image_data.content_type == "image/png"
assert image_data.filename == "2025-06-14_225900_test_task.png"
assert (
issue_registry.async_get_issue(DOMAIN, "organization_verification_required")