From d93e0a105ad79e1fd0ddbab24bda4ae0fcbc4b3b Mon Sep 17 00:00:00 2001 From: Denis Shulyaka Date: Sun, 14 Sep 2025 00:37:39 +0300 Subject: [PATCH] Save AI generated images to files (#152231) --- homeassistant/components/ai_task/__init__.py | 32 ------- homeassistant/components/ai_task/const.py | 6 +- .../components/ai_task/manifest.json | 2 +- .../components/ai_task/media_source.py | 94 +++---------------- homeassistant/components/ai_task/task.py | 79 +++++----------- homeassistant/components/backup/const.py | 1 + tests/components/ai_task/conftest.py | 2 +- tests/components/ai_task/test_init.py | 41 +++++--- tests/components/ai_task/test_media_source.py | 61 +----------- tests/components/ai_task/test_task.py | 71 +++++--------- .../test_ai_task.py | 31 +++--- .../openai_conversation/test_ai_task.py | 31 +++--- 12 files changed, 131 insertions(+), 320 deletions(-) diff --git a/homeassistant/components/ai_task/__init__.py b/homeassistant/components/ai_task/__init__.py index daaf190fc55..767104916bf 100644 --- a/homeassistant/components/ai_task/__init__.py +++ b/homeassistant/components/ai_task/__init__.py @@ -3,10 +3,8 @@ import logging from typing import Any -from aiohttp import web import voluptuous as vol -from homeassistant.components.http import KEY_HASS, HomeAssistantView from homeassistant.config_entries import ConfigEntry from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR from homeassistant.core import ( @@ -28,7 +26,6 @@ from .const import ( ATTR_STRUCTURE, ATTR_TASK_NAME, DATA_COMPONENT, - DATA_IMAGES, DATA_PREFERENCES, DOMAIN, SERVICE_GENERATE_DATA, @@ -42,7 +39,6 @@ from .task import ( GenDataTaskResult, GenImageTask, GenImageTaskResult, - ImageData, async_generate_data, async_generate_image, ) @@ -55,7 +51,6 @@ __all__ = [ "GenDataTaskResult", "GenImageTask", "GenImageTaskResult", - "ImageData", "async_generate_data", "async_generate_image", "async_setup", @@ -94,10 +89,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass) hass.data[DATA_COMPONENT] = entity_component hass.data[DATA_PREFERENCES] = AITaskPreferences(hass) - hass.data[DATA_IMAGES] = {} await hass.data[DATA_PREFERENCES].async_load() async_setup_http(hass) - hass.http.register_view(ImageView) hass.services.async_register( DOMAIN, SERVICE_GENERATE_DATA, @@ -209,28 +202,3 @@ class AITaskPreferences: def as_dict(self) -> dict[str, str | None]: """Get the current preferences.""" return {key: getattr(self, key) for key in self.KEYS} - - -class ImageView(HomeAssistantView): - """View to generated images.""" - - url = f"/api/{DOMAIN}/images/{{filename}}" - name = f"api:{DOMAIN}/images" - - async def get( - self, - request: web.Request, - filename: str, - ) -> web.Response: - """Serve image.""" - hass = request.app[KEY_HASS] - image_storage = hass.data[DATA_IMAGES] - image_data = image_storage.get(filename) - - if image_data is None: - raise web.HTTPNotFound - - return web.Response( - body=image_data.data, - content_type=image_data.mime_type, - ) diff --git a/homeassistant/components/ai_task/const.py b/homeassistant/components/ai_task/const.py index b62f8002ecf..978e6f3cfb9 100644 --- a/homeassistant/components/ai_task/const.py +++ b/homeassistant/components/ai_task/const.py @@ -8,19 +8,19 @@ from typing import TYPE_CHECKING, Final from homeassistant.util.hass_dict import HassKey if TYPE_CHECKING: + from homeassistant.components.media_source import local_source from homeassistant.helpers.entity_component import EntityComponent from . import AITaskPreferences from .entity import AITaskEntity - from .task import ImageData DOMAIN = "ai_task" DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN) DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences") -DATA_IMAGES: HassKey[dict[str, ImageData]] = HassKey(f"{DOMAIN}_images") +DATA_MEDIA_SOURCE: HassKey[local_source.LocalSource] = HassKey(f"{DOMAIN}_media_source") +IMAGE_DIR: Final = "image" IMAGE_EXPIRY_TIME = 60 * 60 # 1 hour -MAX_IMAGES = 20 SERVICE_GENERATE_DATA = "generate_data" SERVICE_GENERATE_IMAGE = "generate_image" diff --git a/homeassistant/components/ai_task/manifest.json b/homeassistant/components/ai_task/manifest.json index 9e2eec4651d..d05faf18055 100644 --- a/homeassistant/components/ai_task/manifest.json +++ b/homeassistant/components/ai_task/manifest.json @@ -1,7 +1,7 @@ { "domain": "ai_task", "name": "AI Task", - "after_dependencies": ["camera", "http"], + "after_dependencies": ["camera"], "codeowners": ["@home-assistant/core"], "dependencies": ["conversation", "media_source"], "documentation": "https://www.home-assistant.io/integrations/ai_task", diff --git a/homeassistant/components/ai_task/media_source.py b/homeassistant/components/ai_task/media_source.py index 17995584fd7..2906acf7a2d 100644 --- a/homeassistant/components/ai_task/media_source.py +++ b/homeassistant/components/ai_task/media_source.py @@ -2,89 +2,21 @@ from __future__ import annotations -from datetime import timedelta -import logging - -from homeassistant.components.http.auth import async_sign_path -from homeassistant.components.media_player import BrowseError, MediaClass -from homeassistant.components.media_source import ( - BrowseMediaSource, - MediaSource, - MediaSourceItem, - PlayMedia, - Unresolvable, -) +from homeassistant.components.media_source import MediaSource, local_source from homeassistant.core import HomeAssistant -from .const import DATA_IMAGES, DOMAIN, IMAGE_EXPIRY_TIME - -_LOGGER = logging.getLogger(__name__) +from .const import DATA_MEDIA_SOURCE, DOMAIN, IMAGE_DIR -async def async_get_media_source(hass: HomeAssistant) -> ImageMediaSource: - """Set up image media source.""" - _LOGGER.debug("Setting up image media source") - return ImageMediaSource(hass) +async def async_get_media_source(hass: HomeAssistant) -> MediaSource: + """Set up local media source.""" + media_dir = hass.config.path(f"{DOMAIN}/{IMAGE_DIR}") - -class ImageMediaSource(MediaSource): - """Provide images as media sources.""" - - name: str = "AI Generated Images" - - def __init__(self, hass: HomeAssistant) -> None: - """Initialize ImageMediaSource.""" - super().__init__(DOMAIN) - self.hass = hass - - async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia: - """Resolve media to a url.""" - image_storage = self.hass.data[DATA_IMAGES] - image = image_storage.get(item.identifier) - - if image is None: - raise Unresolvable(f"Could not resolve media item: {item.identifier}") - - return PlayMedia( - async_sign_path( - self.hass, - f"/api/{DOMAIN}/images/{item.identifier}", - timedelta(seconds=IMAGE_EXPIRY_TIME or 1800), - ), - image.mime_type, - ) - - async def async_browse_media( - self, - item: MediaSourceItem, - ) -> BrowseMediaSource: - """Return media.""" - if item.identifier: - raise BrowseError("Unknown item") - - image_storage = self.hass.data[DATA_IMAGES] - - children = [ - BrowseMediaSource( - domain=DOMAIN, - identifier=filename, - media_class=MediaClass.IMAGE, - media_content_type=image.mime_type, - title=image.title or filename, - can_play=True, - can_expand=False, - ) - for filename, image in image_storage.items() - ] - - return BrowseMediaSource( - domain=DOMAIN, - identifier=None, - media_class=MediaClass.APP, - media_content_type="", - title="AI Generated Images", - can_play=False, - can_expand=True, - children_media_class=MediaClass.IMAGE, - children=children, - ) + hass.data[DATA_MEDIA_SOURCE] = source = local_source.LocalSource( + hass, + DOMAIN, + "AI Generated Images", + {IMAGE_DIR: media_dir}, + f"/{DOMAIN}", + ) + return source diff --git a/homeassistant/components/ai_task/task.py b/homeassistant/components/ai_task/task.py index 5cd57395d9d..e6d86bee978 100644 --- a/homeassistant/components/ai_task/task.py +++ b/homeassistant/components/ai_task/task.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass from datetime import datetime, timedelta -from functools import partial +import io import mimetypes from pathlib import Path import tempfile @@ -18,16 +18,15 @@ from homeassistant.core import HomeAssistant, ServiceResponse, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import llm from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session -from homeassistant.helpers.event import async_call_later from homeassistant.util import RE_SANITIZE_FILENAME, slugify from .const import ( DATA_COMPONENT, - DATA_IMAGES, + DATA_MEDIA_SOURCE, DATA_PREFERENCES, DOMAIN, + IMAGE_DIR, IMAGE_EXPIRY_TIME, - MAX_IMAGES, AITaskEntityFeature, ) @@ -157,24 +156,6 @@ async def async_generate_data( ) -def _cleanup_images(image_storage: dict[str, ImageData], num_to_remove: int) -> None: - """Remove old images to keep the storage size under the limit.""" - if num_to_remove <= 0: - return - - if num_to_remove >= len(image_storage): - image_storage.clear() - return - - sorted_images = sorted( - image_storage.items(), - key=lambda item: item[1].timestamp, - ) - - for filename, _ in sorted_images[:num_to_remove]: - image_storage.pop(filename, None) - - async def async_generate_image( hass: HomeAssistant, *, @@ -224,36 +205,34 @@ async def async_generate_image( if service_result.get("revised_prompt") is None: service_result["revised_prompt"] = instructions - image_storage = hass.data[DATA_IMAGES] - - if len(image_storage) + 1 > MAX_IMAGES: - _cleanup_images(image_storage, len(image_storage) + 1 - MAX_IMAGES) + source = hass.data[DATA_MEDIA_SOURCE] current_time = datetime.now() ext = mimetypes.guess_extension(task_result.mime_type, False) or ".png" sanitized_task_name = RE_SANITIZE_FILENAME.sub("", slugify(task_name)) - filename = f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}" - image_storage[filename] = ImageData( - data=image_data, - timestamp=int(current_time.timestamp()), - mime_type=task_result.mime_type, - title=service_result["revised_prompt"], + image_file = ImageData( + filename=f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}", + file=io.BytesIO(image_data), + content_type=task_result.mime_type, ) - def _purge_image(filename: str, now: datetime) -> None: - """Remove image from storage.""" - image_storage.pop(filename, None) + target_folder = media_source.MediaSourceItem.from_uri( + hass, f"media-source://{DOMAIN}/{IMAGE_DIR}", None + ) - if IMAGE_EXPIRY_TIME > 0: - async_call_later(hass, IMAGE_EXPIRY_TIME, partial(_purge_image, filename)) + service_result["media_source_id"] = await source.async_upload_media( + target_folder, image_file + ) + item = media_source.MediaSourceItem.from_uri( + hass, service_result["media_source_id"], None + ) service_result["url"] = async_sign_path( hass, - f"/api/{DOMAIN}/images/{filename}", - timedelta(seconds=IMAGE_EXPIRY_TIME or 1800), + (await source.async_resolve_media(item)).url, + timedelta(seconds=IMAGE_EXPIRY_TIME), ) - service_result["media_source_id"] = f"media-source://{DOMAIN}/images/{filename}" return service_result @@ -358,20 +337,8 @@ class GenImageTaskResult: @dataclass(slots=True) class ImageData: - """Image data for stored generated images.""" + """Implementation of media_source.local_source.UploadedFile protocol.""" - data: bytes - """Raw image data.""" - - timestamp: int - """Timestamp when the image was generated, as a Unix timestamp.""" - - mime_type: str - """MIME type of the image.""" - - title: str - """Title of the image, usually the prompt used to generate it.""" - - def __str__(self) -> str: - """Return image data as a string.""" - return f"" + filename: str + file: io.IOBase + content_type: str diff --git a/homeassistant/components/backup/const.py b/homeassistant/components/backup/const.py index 773deaef174..1cfb796bd2e 100644 --- a/homeassistant/components/backup/const.py +++ b/homeassistant/components/backup/const.py @@ -26,6 +26,7 @@ EXCLUDE_FROM_BACKUP = [ "tmp_backups/*.tar", "OZW_Log.txt", "tts/*", + "ai_task/*", ] EXCLUDE_DATABASE_FROM_BACKUP = [ diff --git a/tests/components/ai_task/conftest.py b/tests/components/ai_task/conftest.py index 06f9a56a813..ceffb7c055e 100644 --- a/tests/components/ai_task/conftest.py +++ b/tests/components/ai_task/conftest.py @@ -157,4 +157,4 @@ async def init_components( with mock_config_flow(TEST_DOMAIN, ConfigFlow): assert await hass.config_entries.async_setup(mock_config_entry.entry_id) - await hass.async_block_till_done() + await hass.async_block_till_done(wait_background_tasks=True) diff --git a/tests/components/ai_task/test_init.py b/tests/components/ai_task/test_init.py index 5c6465936d9..83e1808b6d8 100644 --- a/tests/components/ai_task/test_init.py +++ b/tests/components/ai_task/test_init.py @@ -4,13 +4,14 @@ from pathlib import Path from typing import Any from unittest.mock import patch +from freezegun import freeze_time from freezegun.api import FrozenDateTimeFactory import pytest import voluptuous as vol from homeassistant.components import media_source from homeassistant.components.ai_task import AITaskPreferences -from homeassistant.components.ai_task.const import DATA_PREFERENCES +from homeassistant.components.ai_task.const import DATA_MEDIA_SOURCE, DATA_PREFERENCES from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import selector @@ -291,6 +292,7 @@ async def test_generate_data_service_invalid_structure( ), ], ) +@freeze_time("2025-06-14 22:59:00") async def test_generate_image_service( hass: HomeAssistant, init_components: None, @@ -302,21 +304,32 @@ async def test_generate_image_service( preferences = hass.data[DATA_PREFERENCES] preferences.async_set_preferences(**set_preferences) - result = await hass.services.async_call( - "ai_task", - "generate_image", - { - "task_name": "Test Image", - "instructions": "Generate a test image", - } - | msg_extra, - blocking=True, - return_response=True, - ) + with patch.object( + hass.data[DATA_MEDIA_SOURCE], + "async_upload_media", + return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png", + ) as mock_upload_media: + result = await hass.services.async_call( + "ai_task", + "generate_image", + { + "task_name": "Test Image", + "instructions": "Generate a test image", + } + | msg_extra, + blocking=True, + return_response=True, + ) + mock_upload_media.assert_called_once() assert "image_data" not in result - assert result["media_source_id"].startswith("media-source://ai_task/images/") - assert result["url"].startswith("/api/ai_task/images/") + assert ( + result["media_source_id"] + == "media-source://ai_task/image/2025-06-14_225900_test_task.png" + ) + assert result["url"].startswith( + "/ai_task/image/2025-06-14_225900_test_task.png?authSig=" + ) assert result["mime_type"] == "image/png" assert result["model"] == "mock_model" assert result["revised_prompt"] == "mock_revised_prompt" diff --git a/tests/components/ai_task/test_media_source.py b/tests/components/ai_task/test_media_source.py index eae597efb91..18f1834e082 100644 --- a/tests/components/ai_task/test_media_source.py +++ b/tests/components/ai_task/test_media_source.py @@ -1,64 +1,11 @@ """Test ai_task media source.""" -import pytest - from homeassistant.components import media_source -from homeassistant.components.ai_task import ImageData from homeassistant.core import HomeAssistant -@pytest.fixture(name="image_id") -async def mock_image_generate(hass: HomeAssistant) -> str: - """Mock image generation and return the image_id.""" - image_storage = hass.data.setdefault("ai_task_images", {}) - filename = "2025-06-15_150640_test_task.png" - image_storage[filename] = ImageData( - data=b"A", - timestamp=1750000000, - mime_type="image/png", - title="Mock Image", - ) - return filename +async def test_local_media_source(hass: HomeAssistant, init_components: None) -> None: + """Test that the image media source is created.""" + item = await media_source.async_browse_media(hass, "media-source://") - -async def test_browsing( - hass: HomeAssistant, init_components: None, image_id: str -) -> None: - """Test browsing image media source.""" - item = await media_source.async_browse_media(hass, "media-source://ai_task") - - assert item is not None - assert item.title == "AI Generated Images" - assert len(item.children) == 1 - assert item.children[0].media_content_type == "image/png" - assert item.children[0].identifier == image_id - assert item.children[0].title == "Mock Image" - - with pytest.raises( - media_source.BrowseError, - match="Unknown item", - ): - await media_source.async_browse_media( - hass, "media-source://ai_task/invalid_path" - ) - - -async def test_resolving( - hass: HomeAssistant, init_components: None, image_id: str -) -> None: - """Test resolving.""" - item = await media_source.async_resolve_media( - hass, f"media-source://ai_task/{image_id}", None - ) - assert item is not None - assert item.url.startswith(f"/api/ai_task/images/{image_id}?authSig=") - assert item.mime_type == "image/png" - - invalid_id = "aabbccddeeff" - with pytest.raises( - media_source.Unresolvable, - match=f"Could not resolve media item: {invalid_id}", - ): - await media_source.async_resolve_media( - hass, f"media-source://ai_task/{invalid_id}", None - ) + assert any(c.title == "AI Generated Images" for c in item.children) diff --git a/tests/components/ai_task/test_task.py b/tests/components/ai_task/test_task.py index bc8bff4e632..345d6c30981 100644 --- a/tests/components/ai_task/test_task.py +++ b/tests/components/ai_task/test_task.py @@ -1,6 +1,6 @@ """Test tasks for the AI Task integration.""" -from datetime import datetime, timedelta +from datetime import timedelta from pathlib import Path from unittest.mock import patch @@ -11,10 +11,10 @@ from syrupy.assertion import SnapshotAssertion from homeassistant.components import media_source from homeassistant.components.ai_task import ( AITaskEntityFeature, - ImageData, async_generate_data, async_generate_image, ) +from homeassistant.components.ai_task.const import DATA_MEDIA_SOURCE from homeassistant.components.camera import Image from homeassistant.components.conversation import async_get_chat_log from homeassistant.const import STATE_UNKNOWN @@ -257,6 +257,7 @@ async def test_generate_data_mixed_attachments( assert media_attachment.path == Path("/media/test.mp4") +@freeze_time("2025-06-14 22:59:00") async def test_generate_image( hass: HomeAssistant, init_components: None, @@ -277,17 +278,26 @@ async def test_generate_image( assert state is not None assert state.state == STATE_UNKNOWN - result = await async_generate_image( - hass, - task_name="Test Task", - entity_id=TEST_ENTITY_ID, - instructions="Test prompt", - ) + with patch.object( + hass.data[DATA_MEDIA_SOURCE], + "async_upload_media", + return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png", + ) as mock_upload_media: + result = await async_generate_image( + hass, + task_name="Test Task", + entity_id=TEST_ENTITY_ID, + instructions="Test prompt", + ) + mock_upload_media.assert_called_once() assert "image_data" not in result - assert result["media_source_id"].startswith("media-source://ai_task/images/") - assert result["media_source_id"].endswith("_test_task.png") - assert result["url"].startswith("/api/ai_task/images/") - assert result["url"].count("_test_task.png?authSig=") == 1 + assert ( + result["media_source_id"] + == "media-source://ai_task/image/2025-06-14_225900_test_task.png" + ) + assert result["url"].startswith( + "/ai_task/image/2025-06-14_225900_test_task.png?authSig=" + ) assert result["mime_type"] == "image/png" assert result["model"] == "mock_model" assert result["revised_prompt"] == "mock_revised_prompt" @@ -309,40 +319,3 @@ async def test_generate_image( entity_id=TEST_ENTITY_ID, instructions="Test prompt", ) - - -async def test_image_cleanup( - hass: HomeAssistant, - init_components: None, - mock_ai_task_entity: MockAITaskEntity, -) -> None: - """Test image cache cleanup.""" - image_storage = hass.data.setdefault("ai_task_images", {}) - image_storage.clear() - image_storage.update( - { - str(idx): ImageData( - data=b"mock_image_data", - timestamp=int(datetime.now().timestamp()), - mime_type="image/png", - title="Test Image", - ) - for idx in range(20) - } - ) - assert len(image_storage) == 20 - - result = await async_generate_image( - hass, - task_name="Test Task", - entity_id=TEST_ENTITY_ID, - instructions="Test prompt", - ) - - assert result["url"].split("?authSig=")[0].split("/")[-1] in image_storage - assert len(image_storage) == 20 - - async_fire_time_changed(hass, dt_util.utcnow() + timedelta(hours=1, seconds=1)) - await hass.async_block_till_done() - - assert len(image_storage) == 19 diff --git a/tests/components/google_generative_ai_conversation/test_ai_task.py b/tests/components/google_generative_ai_conversation/test_ai_task.py index 11e6864d312..25799ef4bc1 100644 --- a/tests/components/google_generative_ai_conversation/test_ai_task.py +++ b/tests/components/google_generative_ai_conversation/test_ai_task.py @@ -3,6 +3,7 @@ from pathlib import Path from unittest.mock import AsyncMock, Mock, patch +from freezegun import freeze_time from google.genai.types import File, FileState, GenerateContentResponse import pytest import voluptuous as vol @@ -222,6 +223,7 @@ async def test_generate_data( @pytest.mark.usefixtures("mock_init_component") +@freeze_time("2025-06-14 22:59:00") async def test_generate_image( hass: HomeAssistant, mock_config_entry: MockConfigEntry, @@ -255,14 +257,17 @@ async def test_generate_image( ], ) - assert hass.data[ai_task.DATA_IMAGES] == {} - - result = await ai_task.async_generate_image( - hass, - task_name="Test Task", - entity_id="ai_task.google_ai_task", - instructions="Generate a test image", - ) + with patch.object( + media_source.local_source.LocalSource, + "async_upload_media", + return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png", + ) as mock_upload_media: + result = await ai_task.async_generate_image( + hass, + task_name="Test Task", + entity_id="ai_task.google_ai_task", + instructions="Generate a test image", + ) assert result["height"] is None assert result["width"] is None @@ -270,11 +275,11 @@ async def test_generate_image( assert result["mime_type"] == "image/png" assert result["model"] == RECOMMENDED_IMAGE_MODEL.partition("/")[-1] - assert len(hass.data[ai_task.DATA_IMAGES]) == 1 - image_data = next(iter(hass.data[ai_task.DATA_IMAGES].values())) - assert image_data.data == mock_image_data - assert image_data.mime_type == "image/png" - assert image_data.title == "Generate a test image" + mock_upload_media.assert_called_once() + image_data = mock_upload_media.call_args[0][1] + assert image_data.file.getvalue() == mock_image_data + assert image_data.content_type == "image/png" + assert image_data.filename == "2025-06-14_225900_test_task.png" # Verify that generate_content was called with correct parameters assert mock_generate_content.called diff --git a/tests/components/openai_conversation/test_ai_task.py b/tests/components/openai_conversation/test_ai_task.py index 31a9212bff2..51ac505893e 100644 --- a/tests/components/openai_conversation/test_ai_task.py +++ b/tests/components/openai_conversation/test_ai_task.py @@ -3,6 +3,7 @@ from pathlib import Path from unittest.mock import AsyncMock, patch +from freezegun import freeze_time import httpx from openai import PermissionDeniedError import pytest @@ -212,6 +213,7 @@ async def test_generate_data_with_attachments( @pytest.mark.usefixtures("mock_init_component") +@freeze_time("2025-06-14 22:59:00") async def test_generate_image( hass: HomeAssistant, mock_config_entry: MockConfigEntry, @@ -241,14 +243,17 @@ async def test_generate_image( create_message_item(id="msg_A", text="", output_index=1), ] - assert hass.data[ai_task.DATA_IMAGES] == {} - - result = await ai_task.async_generate_image( - hass, - task_name="Test Task", - entity_id="ai_task.openai_ai_task", - instructions="Generate test image", - ) + with patch.object( + media_source.local_source.LocalSource, + "async_upload_media", + return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png", + ) as mock_upload_media: + result = await ai_task.async_generate_image( + hass, + task_name="Test Task", + entity_id="ai_task.openai_ai_task", + instructions="Generate test image", + ) assert result["height"] == 1024 assert result["width"] == 1536 @@ -256,11 +261,11 @@ async def test_generate_image( assert result["mime_type"] == "image/png" assert result["model"] == "gpt-image-1" - assert len(hass.data[ai_task.DATA_IMAGES]) == 1 - image_data = next(iter(hass.data[ai_task.DATA_IMAGES].values())) - assert image_data.data == b"A" - assert image_data.mime_type == "image/png" - assert image_data.title == "Mock revised prompt." + mock_upload_media.assert_called_once() + image_data = mock_upload_media.call_args[0][1] + assert image_data.file.getvalue() == b"A" + assert image_data.content_type == "image/png" + assert image_data.filename == "2025-06-14_225900_test_task.png" assert ( issue_registry.async_get_issue(DOMAIN, "organization_verification_required")