Allow AI Task to handle camera attachments (#148753)

This commit is contained in:
Paulus Schoutsen 2025-07-15 08:51:08 +02:00 committed by GitHub
parent 816977dd75
commit e2cc51f21d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 167 additions and 28 deletions

View File

@ -13,7 +13,7 @@ from homeassistant.components.conversation import (
) )
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.helpers import llm from homeassistant.helpers import llm
from homeassistant.helpers.chat_session import async_get_chat_session from homeassistant.helpers.chat_session import ChatSession
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -56,12 +56,12 @@ class AITaskEntity(RestoreEntity):
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def _async_get_ai_task_chat_log( async def _async_get_ai_task_chat_log(
self, self,
session: ChatSession,
task: GenDataTask, task: GenDataTask,
) -> AsyncGenerator[ChatLog]: ) -> AsyncGenerator[ChatLog]:
"""Context manager used to manage the ChatLog used during an AI Task.""" """Context manager used to manage the ChatLog used during an AI Task."""
# pylint: disable-next=contextmanager-generator-missing-cleanup # pylint: disable-next=contextmanager-generator-missing-cleanup
with ( with (
async_get_chat_session(self.hass) as session,
async_get_chat_log( async_get_chat_log(
self.hass, self.hass,
session, session,
@ -88,12 +88,13 @@ class AITaskEntity(RestoreEntity):
@final @final
async def internal_async_generate_data( async def internal_async_generate_data(
self, self,
session: ChatSession,
task: GenDataTask, task: GenDataTask,
) -> GenDataTaskResult: ) -> GenDataTaskResult:
"""Run a gen data task.""" """Run a gen data task."""
self.__last_activity = dt_util.utcnow().isoformat() self.__last_activity = dt_util.utcnow().isoformat()
self.async_write_ha_state() self.async_write_ha_state()
async with self._async_get_ai_task_chat_log(task) as chat_log: async with self._async_get_ai_task_chat_log(session, task) as chat_log:
return await self._async_generate_data(task, chat_log) return await self._async_generate_data(task, chat_log)
async def _async_generate_data( async def _async_generate_data(

View File

@ -1,6 +1,7 @@
{ {
"domain": "ai_task", "domain": "ai_task",
"name": "AI Task", "name": "AI Task",
"after_dependencies": ["camera"],
"codeowners": ["@home-assistant/core"], "codeowners": ["@home-assistant/core"],
"dependencies": ["conversation", "media_source"], "dependencies": ["conversation", "media_source"],
"documentation": "https://www.home-assistant.io/integrations/ai_task", "documentation": "https://www.home-assistant.io/integrations/ai_task",

View File

@ -3,17 +3,32 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
import mimetypes
from pathlib import Path
import tempfile
from typing import Any from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation, media_source from homeassistant.components import camera, conversation, media_source
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.chat_session import async_get_chat_session
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
def _save_camera_snapshot(image: camera.Image) -> Path:
    """Write a camera image to a temporary file and return the file's path.

    The file suffix is guessed from the image's content type using
    non-strict MIME lookup. The file is created with ``delete=False``, so it
    stays on disk after this function returns and must be removed by the
    caller.
    """
    extension = mimetypes.guess_extension(image.content_type, strict=False)
    snapshot_file = tempfile.NamedTemporaryFile(
        mode="wb", suffix=extension, delete=False
    )
    # Leaving the context closes (and flushes) the file without deleting it.
    with snapshot_file:
        snapshot_file.write(image.content)
    return Path(snapshot_file.name)
async def async_generate_data( async def async_generate_data(
hass: HomeAssistant, hass: HomeAssistant,
*, *,
@ -40,40 +55,78 @@ async def async_generate_data(
) )
# Resolve attachments # Resolve attachments
resolved_attachments: list[conversation.Attachment] | None = None resolved_attachments: list[conversation.Attachment] = []
created_files: list[Path] = []
if attachments: if (
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features: attachments
and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features
):
raise HomeAssistantError( raise HomeAssistantError(
f"AI Task entity {entity_id} does not support attachments" f"AI Task entity {entity_id} does not support attachments"
) )
resolved_attachments = [] for attachment in attachments or []:
media_content_id = attachment["media_content_id"]
for attachment in attachments: # Special case for camera media sources
media = await media_source.async_resolve_media( if media_content_id.startswith("media-source://camera/"):
hass, attachment["media_content_id"], None # Extract entity_id from the media content ID
entity_id = media_content_id.removeprefix("media-source://camera/")
# Get snapshot from camera
image = await camera.async_get_image(hass, entity_id)
temp_filename = await hass.async_add_executor_job(
_save_camera_snapshot, image
) )
created_files.append(temp_filename)
resolved_attachments.append(
conversation.Attachment(
media_content_id=media_content_id,
mime_type=image.content_type,
path=temp_filename,
)
)
else:
# Handle regular media sources
media = await media_source.async_resolve_media(hass, media_content_id, None)
if media.path is None: if media.path is None:
raise HomeAssistantError( raise HomeAssistantError(
"Only local attachments are currently supported" "Only local attachments are currently supported"
) )
resolved_attachments.append( resolved_attachments.append(
conversation.Attachment( conversation.Attachment(
media_content_id=attachment["media_content_id"], media_content_id=media_content_id,
url=media.url,
mime_type=media.mime_type, mime_type=media.mime_type,
path=media.path, path=media.path,
) )
) )
with async_get_chat_session(hass) as session:
if created_files:
def cleanup_files() -> None:
"""Cleanup temporary files."""
for file in created_files:
file.unlink(missing_ok=True)
@callback
def cleanup_files_callback() -> None:
"""Cleanup temporary files."""
hass.async_add_executor_job(cleanup_files)
session.async_on_cleanup(cleanup_files_callback)
return await entity.internal_async_generate_data( return await entity.internal_async_generate_data(
session,
GenDataTask( GenDataTask(
name=task_name, name=task_name,
instructions=instructions, instructions=instructions,
structure=structure, structure=structure,
attachments=resolved_attachments, attachments=resolved_attachments or None,
) ),
) )

View File

@ -147,9 +147,6 @@ class Attachment:
media_content_id: str media_content_id: str
"""Media content ID of the attachment.""" """Media content ID of the attachment."""
url: str
"""URL of the attachment."""
mime_type: str mime_type: str
"""MIME type of the attachment.""" """MIME type of the attachment."""

View File

@ -117,7 +117,6 @@ async def test_generate_data_service(
for msg_attachment, attachment in zip( for msg_attachment, attachment in zip(
msg_attachments, task.attachments or [], strict=False msg_attachments, task.attachments or [], strict=False
): ):
assert attachment.url == "http://example.com/media.mp4"
assert attachment.mime_type == "video/mp4" assert attachment.mime_type == "video/mp4"
assert attachment.media_content_id == msg_attachment["media_content_id"] assert attachment.media_content_id == msg_attachment["media_content_id"]
assert attachment.path == Path("media.mp4") assert attachment.path == Path("media.mp4")

View File

@ -1,18 +1,26 @@
"""Test tasks for the AI Task integration.""" """Test tasks for the AI Task integration."""
from datetime import timedelta
from pathlib import Path
from unittest.mock import patch
from freezegun import freeze_time from freezegun import freeze_time
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components import media_source
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data
from homeassistant.components.camera import Image
from homeassistant.components.conversation import async_get_chat_log from homeassistant.components.conversation import async_get_chat_log
from homeassistant.const import STATE_UNKNOWN from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session from homeassistant.helpers import chat_session
from homeassistant.util import dt as dt_util
from .conftest import TEST_ENTITY_ID, MockAITaskEntity from .conftest import TEST_ENTITY_ID, MockAITaskEntity
from tests.common import async_fire_time_changed
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
@ -154,3 +162,83 @@ async def test_generate_data_attachments_not_supported(
} }
], ],
) )
async def test_generate_data_mixed_attachments(
    hass: HomeAssistant,
    init_components: None,
    mock_ai_task_entity: MockAITaskEntity,
) -> None:
    """Test generating data with both camera and regular media source attachments.

    The camera attachment must be served from a snapshot written to a
    temporary file, the other attachment from the resolved media path, and
    the temporary file must be gone once the chat session has timed out.
    """
    camera_media_id = "media-source://camera/camera.front_door"
    video_media_id = "media-source://media_player/video.mp4"
    fake_snapshot = Image(content_type="image/jpeg", content=b"fake_camera_jpeg")
    fake_video = media_source.PlayMedia(
        url="http://example.com/test.mp4",
        mime_type="video/mp4",
        path=Path("/media/test.mp4"),
    )

    with (
        patch(
            "homeassistant.components.camera.async_get_image",
            return_value=fake_snapshot,
        ) as get_image_mock,
        patch(
            "homeassistant.components.media_source.async_resolve_media",
            return_value=fake_video,
        ) as resolve_media_mock,
    ):
        await async_generate_data(
            hass,
            task_name="Test Task",
            entity_id=TEST_ENTITY_ID,
            instructions="Analyze these files",
            attachments=[
                {
                    "media_content_id": camera_media_id,
                    "media_content_type": "image/jpeg",
                },
                {
                    "media_content_id": video_media_id,
                    "media_content_type": "video/mp4",
                },
            ],
        )

    # The camera helper handles the camera attachment; the media source
    # resolver handles only the other one.
    get_image_mock.assert_called_once_with(hass, "camera.front_door")
    resolve_media_mock.assert_called_once_with(hass, video_media_id, None)

    # Exactly one task was recorded, carrying both attachments in order.
    recorded_tasks = mock_ai_task_entity.mock_generate_data_tasks
    assert len(recorded_tasks) == 1
    attachments = recorded_tasks[0].attachments
    assert attachments is not None
    assert len(attachments) == 2
    from_camera, from_media = attachments

    # Camera attachment: metadata plus a snapshot file with a .jpg suffix.
    assert from_camera.media_content_id == camera_media_id
    assert from_camera.mime_type == "image/jpeg"
    assert isinstance(from_camera.path, Path)
    assert from_camera.path.suffix == ".jpg"

    # The snapshot file holds the bytes returned by the mocked camera.
    assert from_camera.path.exists()
    snapshot_bytes = await hass.async_add_executor_job(from_camera.path.read_bytes)
    assert snapshot_bytes == b"fake_camera_jpeg"

    # Move time past the conversation timeout to trigger the clean up.
    expiry = dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + timedelta(seconds=1)
    async_fire_time_changed(hass, expiry)
    await hass.async_block_till_done()

    # The temporary snapshot file has been removed.
    assert not from_camera.path.exists()

    # Regular media attachment: metadata and path come from the resolver.
    assert from_media.media_content_id == video_media_id
    assert from_media.mime_type == "video/mp4"
    assert from_media.path == Path("/media/test.mp4")