Allow AI Task to handle camera attachments (#148753)

This commit is contained in:
Paulus Schoutsen 2025-07-15 08:51:08 +02:00 committed by GitHub
parent 816977dd75
commit e2cc51f21d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 167 additions and 28 deletions

View File

@ -13,7 +13,7 @@ from homeassistant.components.conversation import (
)
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.helpers import llm
from homeassistant.helpers.chat_session import async_get_chat_session
from homeassistant.helpers.chat_session import ChatSession
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util
@ -56,12 +56,12 @@ class AITaskEntity(RestoreEntity):
@contextlib.asynccontextmanager
async def _async_get_ai_task_chat_log(
self,
session: ChatSession,
task: GenDataTask,
) -> AsyncGenerator[ChatLog]:
"""Context manager used to manage the ChatLog used during an AI Task."""
# pylint: disable-next=contextmanager-generator-missing-cleanup
with (
async_get_chat_session(self.hass) as session,
async_get_chat_log(
self.hass,
session,
@ -88,12 +88,13 @@ class AITaskEntity(RestoreEntity):
@final
async def internal_async_generate_data(
self,
session: ChatSession,
task: GenDataTask,
) -> GenDataTaskResult:
"""Run a gen data task."""
self.__last_activity = dt_util.utcnow().isoformat()
self.async_write_ha_state()
async with self._async_get_ai_task_chat_log(task) as chat_log:
async with self._async_get_ai_task_chat_log(session, task) as chat_log:
return await self._async_generate_data(task, chat_log)
async def _async_generate_data(

View File

@ -1,6 +1,7 @@
{
"domain": "ai_task",
"name": "AI Task",
"after_dependencies": ["camera"],
"codeowners": ["@home-assistant/core"],
"dependencies": ["conversation", "media_source"],
"documentation": "https://www.home-assistant.io/integrations/ai_task",

View File

@ -3,17 +3,32 @@
from __future__ import annotations
from dataclasses import dataclass
import mimetypes
from pathlib import Path
import tempfile
from typing import Any
import voluptuous as vol
from homeassistant.components import conversation, media_source
from homeassistant.core import HomeAssistant
from homeassistant.components import camera, conversation, media_source
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.chat_session import async_get_chat_session
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
def _save_camera_snapshot(image: camera.Image) -> Path:
    """Save camera snapshot to a temp file and return the file's path.

    The file is created with ``delete=False`` so it outlives this function;
    the caller owns the returned path and is responsible for removing it.
    This performs blocking file I/O.

    If writing (or the flush performed on close) fails, the partially
    written file is removed before the exception is re-raised, because the
    caller never receives the path and therefore could not clean it up.
    """
    # guess_extension() returns None for MIME types it does not know;
    # NamedTemporaryFile treats suffix=None as "no suffix".
    extension = mimetypes.guess_extension(image.content_type, False)
    snapshot_path: Path | None = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="wb",
            suffix=extension,
            delete=False,
        ) as temp_file:
            snapshot_path = Path(temp_file.name)
            temp_file.write(image.content)
    except Exception:
        # Unlink only after the with-block has closed the handle, so the
        # removal also works on platforms that refuse to delete open files.
        if snapshot_path is not None:
            snapshot_path.unlink(missing_ok=True)
        raise
    return snapshot_path
async def async_generate_data(
hass: HomeAssistant,
*,
@ -40,41 +55,79 @@ async def async_generate_data(
)
# Resolve attachments
resolved_attachments: list[conversation.Attachment] | None = None
resolved_attachments: list[conversation.Attachment] = []
created_files: list[Path] = []
if attachments:
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support attachments"
if (
attachments
and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features
):
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support attachments"
)
for attachment in attachments or []:
media_content_id = attachment["media_content_id"]
# Special case for camera media sources
if media_content_id.startswith("media-source://camera/"):
# Extract entity_id from the media content ID
entity_id = media_content_id.removeprefix("media-source://camera/")
# Get snapshot from camera
image = await camera.async_get_image(hass, entity_id)
temp_filename = await hass.async_add_executor_job(
_save_camera_snapshot, image
)
created_files.append(temp_filename)
resolved_attachments = []
for attachment in attachments:
media = await media_source.async_resolve_media(
hass, attachment["media_content_id"], None
resolved_attachments.append(
conversation.Attachment(
media_content_id=media_content_id,
mime_type=image.content_type,
path=temp_filename,
)
)
else:
# Handle regular media sources
media = await media_source.async_resolve_media(hass, media_content_id, None)
if media.path is None:
raise HomeAssistantError(
"Only local attachments are currently supported"
)
resolved_attachments.append(
conversation.Attachment(
media_content_id=attachment["media_content_id"],
url=media.url,
media_content_id=media_content_id,
mime_type=media.mime_type,
path=media.path,
)
)
return await entity.internal_async_generate_data(
GenDataTask(
name=task_name,
instructions=instructions,
structure=structure,
attachments=resolved_attachments,
with async_get_chat_session(hass) as session:
if created_files:
def cleanup_files() -> None:
"""Cleanup temporary files."""
for file in created_files:
file.unlink(missing_ok=True)
@callback
def cleanup_files_callback() -> None:
"""Cleanup temporary files."""
hass.async_add_executor_job(cleanup_files)
session.async_on_cleanup(cleanup_files_callback)
return await entity.internal_async_generate_data(
session,
GenDataTask(
name=task_name,
instructions=instructions,
structure=structure,
attachments=resolved_attachments or None,
),
)
)
@dataclass(slots=True)

View File

@ -147,9 +147,6 @@ class Attachment:
media_content_id: str
"""Media content ID of the attachment."""
url: str
"""URL of the attachment."""
mime_type: str
"""MIME type of the attachment."""

View File

@ -117,7 +117,6 @@ async def test_generate_data_service(
for msg_attachment, attachment in zip(
msg_attachments, task.attachments or [], strict=False
):
assert attachment.url == "http://example.com/media.mp4"
assert attachment.mime_type == "video/mp4"
assert attachment.media_content_id == msg_attachment["media_content_id"]
assert attachment.path == Path("media.mp4")

View File

@ -1,18 +1,26 @@
"""Test tasks for the AI Task integration."""
from datetime import timedelta
from pathlib import Path
from unittest.mock import patch
from freezegun import freeze_time
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import media_source
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data
from homeassistant.components.camera import Image
from homeassistant.components.conversation import async_get_chat_log
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session
from homeassistant.util import dt as dt_util
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
from tests.common import async_fire_time_changed
from tests.typing import WebSocketGenerator
@ -154,3 +162,83 @@ async def test_generate_data_attachments_not_supported(
}
],
)
async def test_generate_data_mixed_attachments(
    hass: HomeAssistant,
    init_components: None,
    mock_ai_task_entity: MockAITaskEntity,
) -> None:
    """Test a task that mixes a camera snapshot with a regular media source file."""
    camera_media_id = "media-source://camera/camera.front_door"
    video_media_id = "media-source://media_player/video.mp4"

    # Canned results handed back by the two patched resolvers
    snapshot = Image(content_type="image/jpeg", content=b"fake_camera_jpeg")
    resolved_video = media_source.PlayMedia(
        url="http://example.com/test.mp4",
        mime_type="video/mp4",
        path=Path("/media/test.mp4"),
    )

    with (
        patch(
            "homeassistant.components.camera.async_get_image",
            return_value=snapshot,
        ) as get_image_mock,
        patch(
            "homeassistant.components.media_source.async_resolve_media",
            return_value=resolved_video,
        ) as resolve_media_mock,
    ):
        await async_generate_data(
            hass,
            task_name="Test Task",
            entity_id=TEST_ENTITY_ID,
            instructions="Analyze these files",
            attachments=[
                {
                    "media_content_id": camera_media_id,
                    "media_content_type": "image/jpeg",
                },
                {
                    "media_content_id": video_media_id,
                    "media_content_type": "video/mp4",
                },
            ],
        )

    # Each resolver must have been used exactly once, for its own attachment
    get_image_mock.assert_called_once_with(hass, "camera.front_door")
    resolve_media_mock.assert_called_once_with(hass, video_media_id, None)

    # The entity recorded a single task carrying both attachments, in order
    recorded_tasks = mock_ai_task_entity.mock_generate_data_tasks
    assert len(recorded_tasks) == 1
    attachments = recorded_tasks[0].attachments
    assert attachments is not None
    assert len(attachments) == 2
    from_camera, from_media = attachments

    # Camera attachment: keeps its media content ID and points at a .jpg file
    assert from_camera.media_content_id == camera_media_id
    assert from_camera.mime_type == "image/jpeg"
    snapshot_path = from_camera.path
    assert isinstance(snapshot_path, Path)
    assert snapshot_path.suffix == ".jpg"

    # The snapshot bytes were written to that file
    assert snapshot_path.exists()
    stored = await hass.async_add_executor_job(snapshot_path.read_bytes)
    assert stored == b"fake_camera_jpeg"

    # Move time past the chat session timeout to trigger the clean up
    expiry = dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + timedelta(seconds=1)
    async_fire_time_changed(hass, expiry)
    await hass.async_block_till_done()

    # The temporary snapshot file has been removed
    assert not snapshot_path.exists()

    # Regular media attachment: passed through from the resolved media
    assert from_media.media_content_id == video_media_id
    assert from_media.mime_type == "video/mp4"
    assert from_media.path == Path("/media/test.mp4")