mirror of
https://github.com/home-assistant/core.git
synced 2026-04-06 23:47:33 +00:00
Save AI generated images to files (#152231)
This commit is contained in:
@@ -3,10 +3,8 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.http import KEY_HASS, HomeAssistantView
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
|
||||
from homeassistant.core import (
|
||||
@@ -28,7 +26,6 @@ from .const import (
|
||||
ATTR_STRUCTURE,
|
||||
ATTR_TASK_NAME,
|
||||
DATA_COMPONENT,
|
||||
DATA_IMAGES,
|
||||
DATA_PREFERENCES,
|
||||
DOMAIN,
|
||||
SERVICE_GENERATE_DATA,
|
||||
@@ -42,7 +39,6 @@ from .task import (
|
||||
GenDataTaskResult,
|
||||
GenImageTask,
|
||||
GenImageTaskResult,
|
||||
ImageData,
|
||||
async_generate_data,
|
||||
async_generate_image,
|
||||
)
|
||||
@@ -55,7 +51,6 @@ __all__ = [
|
||||
"GenDataTaskResult",
|
||||
"GenImageTask",
|
||||
"GenImageTaskResult",
|
||||
"ImageData",
|
||||
"async_generate_data",
|
||||
"async_generate_image",
|
||||
"async_setup",
|
||||
@@ -94,10 +89,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass)
|
||||
hass.data[DATA_COMPONENT] = entity_component
|
||||
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
|
||||
hass.data[DATA_IMAGES] = {}
|
||||
await hass.data[DATA_PREFERENCES].async_load()
|
||||
async_setup_http(hass)
|
||||
hass.http.register_view(ImageView)
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
SERVICE_GENERATE_DATA,
|
||||
@@ -209,28 +202,3 @@ class AITaskPreferences:
|
||||
def as_dict(self) -> dict[str, str | None]:
|
||||
"""Get the current preferences."""
|
||||
return {key: getattr(self, key) for key in self.KEYS}
|
||||
|
||||
|
||||
class ImageView(HomeAssistantView):
|
||||
"""View to generated images."""
|
||||
|
||||
url = f"/api/{DOMAIN}/images/{{filename}}"
|
||||
name = f"api:{DOMAIN}/images"
|
||||
|
||||
async def get(
|
||||
self,
|
||||
request: web.Request,
|
||||
filename: str,
|
||||
) -> web.Response:
|
||||
"""Serve image."""
|
||||
hass = request.app[KEY_HASS]
|
||||
image_storage = hass.data[DATA_IMAGES]
|
||||
image_data = image_storage.get(filename)
|
||||
|
||||
if image_data is None:
|
||||
raise web.HTTPNotFound
|
||||
|
||||
return web.Response(
|
||||
body=image_data.data,
|
||||
content_type=image_data.mime_type,
|
||||
)
|
||||
|
||||
@@ -8,19 +8,19 @@ from typing import TYPE_CHECKING, Final
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.components.media_source import local_source
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from . import AITaskPreferences
|
||||
from .entity import AITaskEntity
|
||||
from .task import ImageData
|
||||
|
||||
DOMAIN = "ai_task"
|
||||
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
|
||||
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
|
||||
DATA_IMAGES: HassKey[dict[str, ImageData]] = HassKey(f"{DOMAIN}_images")
|
||||
DATA_MEDIA_SOURCE: HassKey[local_source.LocalSource] = HassKey(f"{DOMAIN}_media_source")
|
||||
|
||||
IMAGE_DIR: Final = "image"
|
||||
IMAGE_EXPIRY_TIME = 60 * 60 # 1 hour
|
||||
MAX_IMAGES = 20
|
||||
|
||||
SERVICE_GENERATE_DATA = "generate_data"
|
||||
SERVICE_GENERATE_IMAGE = "generate_image"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"domain": "ai_task",
|
||||
"name": "AI Task",
|
||||
"after_dependencies": ["camera", "http"],
|
||||
"after_dependencies": ["camera"],
|
||||
"codeowners": ["@home-assistant/core"],
|
||||
"dependencies": ["conversation", "media_source"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/ai_task",
|
||||
|
||||
@@ -2,89 +2,21 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
|
||||
from homeassistant.components.http.auth import async_sign_path
|
||||
from homeassistant.components.media_player import BrowseError, MediaClass
|
||||
from homeassistant.components.media_source import (
|
||||
BrowseMediaSource,
|
||||
MediaSource,
|
||||
MediaSourceItem,
|
||||
PlayMedia,
|
||||
Unresolvable,
|
||||
)
|
||||
from homeassistant.components.media_source import MediaSource, local_source
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import DATA_IMAGES, DOMAIN, IMAGE_EXPIRY_TIME
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
from .const import DATA_MEDIA_SOURCE, DOMAIN, IMAGE_DIR
|
||||
|
||||
|
||||
async def async_get_media_source(hass: HomeAssistant) -> ImageMediaSource:
|
||||
"""Set up image media source."""
|
||||
_LOGGER.debug("Setting up image media source")
|
||||
return ImageMediaSource(hass)
|
||||
async def async_get_media_source(hass: HomeAssistant) -> MediaSource:
|
||||
"""Set up local media source."""
|
||||
media_dir = hass.config.path(f"{DOMAIN}/{IMAGE_DIR}")
|
||||
|
||||
|
||||
class ImageMediaSource(MediaSource):
|
||||
"""Provide images as media sources."""
|
||||
|
||||
name: str = "AI Generated Images"
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize ImageMediaSource."""
|
||||
super().__init__(DOMAIN)
|
||||
self.hass = hass
|
||||
|
||||
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
||||
"""Resolve media to a url."""
|
||||
image_storage = self.hass.data[DATA_IMAGES]
|
||||
image = image_storage.get(item.identifier)
|
||||
|
||||
if image is None:
|
||||
raise Unresolvable(f"Could not resolve media item: {item.identifier}")
|
||||
|
||||
return PlayMedia(
|
||||
async_sign_path(
|
||||
self.hass,
|
||||
f"/api/{DOMAIN}/images/{item.identifier}",
|
||||
timedelta(seconds=IMAGE_EXPIRY_TIME or 1800),
|
||||
),
|
||||
image.mime_type,
|
||||
)
|
||||
|
||||
async def async_browse_media(
|
||||
self,
|
||||
item: MediaSourceItem,
|
||||
) -> BrowseMediaSource:
|
||||
"""Return media."""
|
||||
if item.identifier:
|
||||
raise BrowseError("Unknown item")
|
||||
|
||||
image_storage = self.hass.data[DATA_IMAGES]
|
||||
|
||||
children = [
|
||||
BrowseMediaSource(
|
||||
domain=DOMAIN,
|
||||
identifier=filename,
|
||||
media_class=MediaClass.IMAGE,
|
||||
media_content_type=image.mime_type,
|
||||
title=image.title or filename,
|
||||
can_play=True,
|
||||
can_expand=False,
|
||||
)
|
||||
for filename, image in image_storage.items()
|
||||
]
|
||||
|
||||
return BrowseMediaSource(
|
||||
domain=DOMAIN,
|
||||
identifier=None,
|
||||
media_class=MediaClass.APP,
|
||||
media_content_type="",
|
||||
title="AI Generated Images",
|
||||
can_play=False,
|
||||
can_expand=True,
|
||||
children_media_class=MediaClass.IMAGE,
|
||||
children=children,
|
||||
)
|
||||
hass.data[DATA_MEDIA_SOURCE] = source = local_source.LocalSource(
|
||||
hass,
|
||||
DOMAIN,
|
||||
"AI Generated Images",
|
||||
{IMAGE_DIR: media_dir},
|
||||
f"/{DOMAIN}",
|
||||
)
|
||||
return source
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
import io
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
@@ -18,16 +18,15 @@ from homeassistant.core import HomeAssistant, ServiceResponse, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.util import RE_SANITIZE_FILENAME, slugify
|
||||
|
||||
from .const import (
|
||||
DATA_COMPONENT,
|
||||
DATA_IMAGES,
|
||||
DATA_MEDIA_SOURCE,
|
||||
DATA_PREFERENCES,
|
||||
DOMAIN,
|
||||
IMAGE_DIR,
|
||||
IMAGE_EXPIRY_TIME,
|
||||
MAX_IMAGES,
|
||||
AITaskEntityFeature,
|
||||
)
|
||||
|
||||
@@ -157,24 +156,6 @@ async def async_generate_data(
|
||||
)
|
||||
|
||||
|
||||
def _cleanup_images(image_storage: dict[str, ImageData], num_to_remove: int) -> None:
|
||||
"""Remove old images to keep the storage size under the limit."""
|
||||
if num_to_remove <= 0:
|
||||
return
|
||||
|
||||
if num_to_remove >= len(image_storage):
|
||||
image_storage.clear()
|
||||
return
|
||||
|
||||
sorted_images = sorted(
|
||||
image_storage.items(),
|
||||
key=lambda item: item[1].timestamp,
|
||||
)
|
||||
|
||||
for filename, _ in sorted_images[:num_to_remove]:
|
||||
image_storage.pop(filename, None)
|
||||
|
||||
|
||||
async def async_generate_image(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
@@ -224,36 +205,34 @@ async def async_generate_image(
|
||||
if service_result.get("revised_prompt") is None:
|
||||
service_result["revised_prompt"] = instructions
|
||||
|
||||
image_storage = hass.data[DATA_IMAGES]
|
||||
|
||||
if len(image_storage) + 1 > MAX_IMAGES:
|
||||
_cleanup_images(image_storage, len(image_storage) + 1 - MAX_IMAGES)
|
||||
source = hass.data[DATA_MEDIA_SOURCE]
|
||||
|
||||
current_time = datetime.now()
|
||||
ext = mimetypes.guess_extension(task_result.mime_type, False) or ".png"
|
||||
sanitized_task_name = RE_SANITIZE_FILENAME.sub("", slugify(task_name))
|
||||
filename = f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}"
|
||||
|
||||
image_storage[filename] = ImageData(
|
||||
data=image_data,
|
||||
timestamp=int(current_time.timestamp()),
|
||||
mime_type=task_result.mime_type,
|
||||
title=service_result["revised_prompt"],
|
||||
image_file = ImageData(
|
||||
filename=f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}",
|
||||
file=io.BytesIO(image_data),
|
||||
content_type=task_result.mime_type,
|
||||
)
|
||||
|
||||
def _purge_image(filename: str, now: datetime) -> None:
|
||||
"""Remove image from storage."""
|
||||
image_storage.pop(filename, None)
|
||||
target_folder = media_source.MediaSourceItem.from_uri(
|
||||
hass, f"media-source://{DOMAIN}/{IMAGE_DIR}", None
|
||||
)
|
||||
|
||||
if IMAGE_EXPIRY_TIME > 0:
|
||||
async_call_later(hass, IMAGE_EXPIRY_TIME, partial(_purge_image, filename))
|
||||
service_result["media_source_id"] = await source.async_upload_media(
|
||||
target_folder, image_file
|
||||
)
|
||||
|
||||
item = media_source.MediaSourceItem.from_uri(
|
||||
hass, service_result["media_source_id"], None
|
||||
)
|
||||
service_result["url"] = async_sign_path(
|
||||
hass,
|
||||
f"/api/{DOMAIN}/images/{filename}",
|
||||
timedelta(seconds=IMAGE_EXPIRY_TIME or 1800),
|
||||
(await source.async_resolve_media(item)).url,
|
||||
timedelta(seconds=IMAGE_EXPIRY_TIME),
|
||||
)
|
||||
service_result["media_source_id"] = f"media-source://{DOMAIN}/images/{filename}"
|
||||
|
||||
return service_result
|
||||
|
||||
@@ -358,20 +337,8 @@ class GenImageTaskResult:
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ImageData:
|
||||
"""Image data for stored generated images."""
|
||||
"""Implementation of media_source.local_source.UploadedFile protocol."""
|
||||
|
||||
data: bytes
|
||||
"""Raw image data."""
|
||||
|
||||
timestamp: int
|
||||
"""Timestamp when the image was generated, as a Unix timestamp."""
|
||||
|
||||
mime_type: str
|
||||
"""MIME type of the image."""
|
||||
|
||||
title: str
|
||||
"""Title of the image, usually the prompt used to generate it."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return image data as a string."""
|
||||
return f"<ImageData {self.title}: {id(self)}>"
|
||||
filename: str
|
||||
file: io.IOBase
|
||||
content_type: str
|
||||
|
||||
@@ -26,6 +26,7 @@ EXCLUDE_FROM_BACKUP = [
|
||||
"tmp_backups/*.tar",
|
||||
"OZW_Log.txt",
|
||||
"tts/*",
|
||||
"ai_task/*",
|
||||
]
|
||||
|
||||
EXCLUDE_DATABASE_FROM_BACKUP = [
|
||||
|
||||
@@ -157,4 +157,4 @@ async def init_components(
|
||||
|
||||
with mock_config_flow(TEST_DOMAIN, ConfigFlow):
|
||||
assert await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
|
||||
@@ -4,13 +4,14 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.components.ai_task import AITaskPreferences
|
||||
from homeassistant.components.ai_task.const import DATA_PREFERENCES
|
||||
from homeassistant.components.ai_task.const import DATA_MEDIA_SOURCE, DATA_PREFERENCES
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import selector
|
||||
@@ -291,6 +292,7 @@ async def test_generate_data_service_invalid_structure(
|
||||
),
|
||||
],
|
||||
)
|
||||
@freeze_time("2025-06-14 22:59:00")
|
||||
async def test_generate_image_service(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
@@ -302,21 +304,32 @@ async def test_generate_image_service(
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
preferences.async_set_preferences(**set_preferences)
|
||||
|
||||
result = await hass.services.async_call(
|
||||
"ai_task",
|
||||
"generate_image",
|
||||
{
|
||||
"task_name": "Test Image",
|
||||
"instructions": "Generate a test image",
|
||||
}
|
||||
| msg_extra,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
with patch.object(
|
||||
hass.data[DATA_MEDIA_SOURCE],
|
||||
"async_upload_media",
|
||||
return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png",
|
||||
) as mock_upload_media:
|
||||
result = await hass.services.async_call(
|
||||
"ai_task",
|
||||
"generate_image",
|
||||
{
|
||||
"task_name": "Test Image",
|
||||
"instructions": "Generate a test image",
|
||||
}
|
||||
| msg_extra,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
mock_upload_media.assert_called_once()
|
||||
assert "image_data" not in result
|
||||
assert result["media_source_id"].startswith("media-source://ai_task/images/")
|
||||
assert result["url"].startswith("/api/ai_task/images/")
|
||||
assert (
|
||||
result["media_source_id"]
|
||||
== "media-source://ai_task/image/2025-06-14_225900_test_task.png"
|
||||
)
|
||||
assert result["url"].startswith(
|
||||
"/ai_task/image/2025-06-14_225900_test_task.png?authSig="
|
||||
)
|
||||
assert result["mime_type"] == "image/png"
|
||||
assert result["model"] == "mock_model"
|
||||
assert result["revised_prompt"] == "mock_revised_prompt"
|
||||
|
||||
@@ -1,64 +1,11 @@
|
||||
"""Test ai_task media source."""
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.components.ai_task import ImageData
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
@pytest.fixture(name="image_id")
|
||||
async def mock_image_generate(hass: HomeAssistant) -> str:
|
||||
"""Mock image generation and return the image_id."""
|
||||
image_storage = hass.data.setdefault("ai_task_images", {})
|
||||
filename = "2025-06-15_150640_test_task.png"
|
||||
image_storage[filename] = ImageData(
|
||||
data=b"A",
|
||||
timestamp=1750000000,
|
||||
mime_type="image/png",
|
||||
title="Mock Image",
|
||||
)
|
||||
return filename
|
||||
async def test_local_media_source(hass: HomeAssistant, init_components: None) -> None:
|
||||
"""Test that the image media source is created."""
|
||||
item = await media_source.async_browse_media(hass, "media-source://")
|
||||
|
||||
|
||||
async def test_browsing(
|
||||
hass: HomeAssistant, init_components: None, image_id: str
|
||||
) -> None:
|
||||
"""Test browsing image media source."""
|
||||
item = await media_source.async_browse_media(hass, "media-source://ai_task")
|
||||
|
||||
assert item is not None
|
||||
assert item.title == "AI Generated Images"
|
||||
assert len(item.children) == 1
|
||||
assert item.children[0].media_content_type == "image/png"
|
||||
assert item.children[0].identifier == image_id
|
||||
assert item.children[0].title == "Mock Image"
|
||||
|
||||
with pytest.raises(
|
||||
media_source.BrowseError,
|
||||
match="Unknown item",
|
||||
):
|
||||
await media_source.async_browse_media(
|
||||
hass, "media-source://ai_task/invalid_path"
|
||||
)
|
||||
|
||||
|
||||
async def test_resolving(
|
||||
hass: HomeAssistant, init_components: None, image_id: str
|
||||
) -> None:
|
||||
"""Test resolving."""
|
||||
item = await media_source.async_resolve_media(
|
||||
hass, f"media-source://ai_task/{image_id}", None
|
||||
)
|
||||
assert item is not None
|
||||
assert item.url.startswith(f"/api/ai_task/images/{image_id}?authSig=")
|
||||
assert item.mime_type == "image/png"
|
||||
|
||||
invalid_id = "aabbccddeeff"
|
||||
with pytest.raises(
|
||||
media_source.Unresolvable,
|
||||
match=f"Could not resolve media item: {invalid_id}",
|
||||
):
|
||||
await media_source.async_resolve_media(
|
||||
hass, f"media-source://ai_task/{invalid_id}", None
|
||||
)
|
||||
assert any(c.title == "AI Generated Images" for c in item.children)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Test tasks for the AI Task integration."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -11,10 +11,10 @@ from syrupy.assertion import SnapshotAssertion
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.components.ai_task import (
|
||||
AITaskEntityFeature,
|
||||
ImageData,
|
||||
async_generate_data,
|
||||
async_generate_image,
|
||||
)
|
||||
from homeassistant.components.ai_task.const import DATA_MEDIA_SOURCE
|
||||
from homeassistant.components.camera import Image
|
||||
from homeassistant.components.conversation import async_get_chat_log
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
@@ -257,6 +257,7 @@ async def test_generate_data_mixed_attachments(
|
||||
assert media_attachment.path == Path("/media/test.mp4")
|
||||
|
||||
|
||||
@freeze_time("2025-06-14 22:59:00")
|
||||
async def test_generate_image(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
@@ -277,17 +278,26 @@ async def test_generate_image(
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
result = await async_generate_image(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
with patch.object(
|
||||
hass.data[DATA_MEDIA_SOURCE],
|
||||
"async_upload_media",
|
||||
return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png",
|
||||
) as mock_upload_media:
|
||||
result = await async_generate_image(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
mock_upload_media.assert_called_once()
|
||||
assert "image_data" not in result
|
||||
assert result["media_source_id"].startswith("media-source://ai_task/images/")
|
||||
assert result["media_source_id"].endswith("_test_task.png")
|
||||
assert result["url"].startswith("/api/ai_task/images/")
|
||||
assert result["url"].count("_test_task.png?authSig=") == 1
|
||||
assert (
|
||||
result["media_source_id"]
|
||||
== "media-source://ai_task/image/2025-06-14_225900_test_task.png"
|
||||
)
|
||||
assert result["url"].startswith(
|
||||
"/ai_task/image/2025-06-14_225900_test_task.png?authSig="
|
||||
)
|
||||
assert result["mime_type"] == "image/png"
|
||||
assert result["model"] == "mock_model"
|
||||
assert result["revised_prompt"] == "mock_revised_prompt"
|
||||
@@ -309,40 +319,3 @@ async def test_generate_image(
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
|
||||
async def test_image_cleanup(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
) -> None:
|
||||
"""Test image cache cleanup."""
|
||||
image_storage = hass.data.setdefault("ai_task_images", {})
|
||||
image_storage.clear()
|
||||
image_storage.update(
|
||||
{
|
||||
str(idx): ImageData(
|
||||
data=b"mock_image_data",
|
||||
timestamp=int(datetime.now().timestamp()),
|
||||
mime_type="image/png",
|
||||
title="Test Image",
|
||||
)
|
||||
for idx in range(20)
|
||||
}
|
||||
)
|
||||
assert len(image_storage) == 20
|
||||
|
||||
result = await async_generate_image(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
assert result["url"].split("?authSig=")[0].split("/")[-1] in image_storage
|
||||
assert len(image_storage) == 20
|
||||
|
||||
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(hours=1, seconds=1))
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(image_storage) == 19
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
from google.genai.types import File, FileState, GenerateContentResponse
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
@@ -222,6 +223,7 @@ async def test_generate_data(
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
@freeze_time("2025-06-14 22:59:00")
|
||||
async def test_generate_image(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
@@ -255,14 +257,17 @@ async def test_generate_image(
|
||||
],
|
||||
)
|
||||
|
||||
assert hass.data[ai_task.DATA_IMAGES] == {}
|
||||
|
||||
result = await ai_task.async_generate_image(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id="ai_task.google_ai_task",
|
||||
instructions="Generate a test image",
|
||||
)
|
||||
with patch.object(
|
||||
media_source.local_source.LocalSource,
|
||||
"async_upload_media",
|
||||
return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png",
|
||||
) as mock_upload_media:
|
||||
result = await ai_task.async_generate_image(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id="ai_task.google_ai_task",
|
||||
instructions="Generate a test image",
|
||||
)
|
||||
|
||||
assert result["height"] is None
|
||||
assert result["width"] is None
|
||||
@@ -270,11 +275,11 @@ async def test_generate_image(
|
||||
assert result["mime_type"] == "image/png"
|
||||
assert result["model"] == RECOMMENDED_IMAGE_MODEL.partition("/")[-1]
|
||||
|
||||
assert len(hass.data[ai_task.DATA_IMAGES]) == 1
|
||||
image_data = next(iter(hass.data[ai_task.DATA_IMAGES].values()))
|
||||
assert image_data.data == mock_image_data
|
||||
assert image_data.mime_type == "image/png"
|
||||
assert image_data.title == "Generate a test image"
|
||||
mock_upload_media.assert_called_once()
|
||||
image_data = mock_upload_media.call_args[0][1]
|
||||
assert image_data.file.getvalue() == mock_image_data
|
||||
assert image_data.content_type == "image/png"
|
||||
assert image_data.filename == "2025-06-14_225900_test_task.png"
|
||||
|
||||
# Verify that generate_content was called with correct parameters
|
||||
assert mock_generate_content.called
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
import httpx
|
||||
from openai import PermissionDeniedError
|
||||
import pytest
|
||||
@@ -212,6 +213,7 @@ async def test_generate_data_with_attachments(
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
@freeze_time("2025-06-14 22:59:00")
|
||||
async def test_generate_image(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
@@ -241,14 +243,17 @@ async def test_generate_image(
|
||||
create_message_item(id="msg_A", text="", output_index=1),
|
||||
]
|
||||
|
||||
assert hass.data[ai_task.DATA_IMAGES] == {}
|
||||
|
||||
result = await ai_task.async_generate_image(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id="ai_task.openai_ai_task",
|
||||
instructions="Generate test image",
|
||||
)
|
||||
with patch.object(
|
||||
media_source.local_source.LocalSource,
|
||||
"async_upload_media",
|
||||
return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png",
|
||||
) as mock_upload_media:
|
||||
result = await ai_task.async_generate_image(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id="ai_task.openai_ai_task",
|
||||
instructions="Generate test image",
|
||||
)
|
||||
|
||||
assert result["height"] == 1024
|
||||
assert result["width"] == 1536
|
||||
@@ -256,11 +261,11 @@ async def test_generate_image(
|
||||
assert result["mime_type"] == "image/png"
|
||||
assert result["model"] == "gpt-image-1"
|
||||
|
||||
assert len(hass.data[ai_task.DATA_IMAGES]) == 1
|
||||
image_data = next(iter(hass.data[ai_task.DATA_IMAGES].values()))
|
||||
assert image_data.data == b"A"
|
||||
assert image_data.mime_type == "image/png"
|
||||
assert image_data.title == "Mock revised prompt."
|
||||
mock_upload_media.assert_called_once()
|
||||
image_data = mock_upload_media.call_args[0][1]
|
||||
assert image_data.file.getvalue() == b"A"
|
||||
assert image_data.content_type == "image/png"
|
||||
assert image_data.filename == "2025-06-14_225900_test_task.png"
|
||||
|
||||
assert (
|
||||
issue_registry.async_get_issue(DOMAIN, "organization_verification_required")
|
||||
|
||||
Reference in New Issue
Block a user