mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 15:47:12 +00:00
Allow AI Task to handle camera attachments (#148753)
This commit is contained in:
parent
816977dd75
commit
e2cc51f21d
@ -13,7 +13,7 @@ from homeassistant.components.conversation import (
|
||||
)
|
||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.chat_session import async_get_chat_session
|
||||
from homeassistant.helpers.chat_session import ChatSession
|
||||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
@ -56,12 +56,12 @@ class AITaskEntity(RestoreEntity):
|
||||
@contextlib.asynccontextmanager
|
||||
async def _async_get_ai_task_chat_log(
|
||||
self,
|
||||
session: ChatSession,
|
||||
task: GenDataTask,
|
||||
) -> AsyncGenerator[ChatLog]:
|
||||
"""Context manager used to manage the ChatLog used during an AI Task."""
|
||||
# pylint: disable-next=contextmanager-generator-missing-cleanup
|
||||
with (
|
||||
async_get_chat_session(self.hass) as session,
|
||||
async_get_chat_log(
|
||||
self.hass,
|
||||
session,
|
||||
@ -88,12 +88,13 @@ class AITaskEntity(RestoreEntity):
|
||||
@final
|
||||
async def internal_async_generate_data(
|
||||
self,
|
||||
session: ChatSession,
|
||||
task: GenDataTask,
|
||||
) -> GenDataTaskResult:
|
||||
"""Run a gen data task."""
|
||||
self.__last_activity = dt_util.utcnow().isoformat()
|
||||
self.async_write_ha_state()
|
||||
async with self._async_get_ai_task_chat_log(task) as chat_log:
|
||||
async with self._async_get_ai_task_chat_log(session, task) as chat_log:
|
||||
return await self._async_generate_data(task, chat_log)
|
||||
|
||||
async def _async_generate_data(
|
||||
|
@ -1,6 +1,7 @@
|
||||
{
|
||||
"domain": "ai_task",
|
||||
"name": "AI Task",
|
||||
"after_dependencies": ["camera"],
|
||||
"codeowners": ["@home-assistant/core"],
|
||||
"dependencies": ["conversation", "media_source"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/ai_task",
|
||||
|
@ -3,17 +3,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import conversation, media_source
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.components import camera, conversation, media_source
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.chat_session import async_get_chat_session
|
||||
|
||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||
|
||||
|
||||
def _save_camera_snapshot(image: camera.Image) -> Path:
|
||||
"""Save camera snapshot to temp file."""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="wb",
|
||||
suffix=mimetypes.guess_extension(image.content_type, False),
|
||||
delete=False,
|
||||
) as temp_file:
|
||||
temp_file.write(image.content)
|
||||
return Path(temp_file.name)
|
||||
|
||||
|
||||
async def async_generate_data(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
@ -40,41 +55,79 @@ async def async_generate_data(
|
||||
)
|
||||
|
||||
# Resolve attachments
|
||||
resolved_attachments: list[conversation.Attachment] | None = None
|
||||
resolved_attachments: list[conversation.Attachment] = []
|
||||
created_files: list[Path] = []
|
||||
|
||||
if attachments:
|
||||
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
|
||||
raise HomeAssistantError(
|
||||
f"AI Task entity {entity_id} does not support attachments"
|
||||
if (
|
||||
attachments
|
||||
and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features
|
||||
):
|
||||
raise HomeAssistantError(
|
||||
f"AI Task entity {entity_id} does not support attachments"
|
||||
)
|
||||
|
||||
for attachment in attachments or []:
|
||||
media_content_id = attachment["media_content_id"]
|
||||
|
||||
# Special case for camera media sources
|
||||
if media_content_id.startswith("media-source://camera/"):
|
||||
# Extract entity_id from the media content ID
|
||||
entity_id = media_content_id.removeprefix("media-source://camera/")
|
||||
|
||||
# Get snapshot from camera
|
||||
image = await camera.async_get_image(hass, entity_id)
|
||||
|
||||
temp_filename = await hass.async_add_executor_job(
|
||||
_save_camera_snapshot, image
|
||||
)
|
||||
created_files.append(temp_filename)
|
||||
|
||||
resolved_attachments = []
|
||||
|
||||
for attachment in attachments:
|
||||
media = await media_source.async_resolve_media(
|
||||
hass, attachment["media_content_id"], None
|
||||
resolved_attachments.append(
|
||||
conversation.Attachment(
|
||||
media_content_id=media_content_id,
|
||||
mime_type=image.content_type,
|
||||
path=temp_filename,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Handle regular media sources
|
||||
media = await media_source.async_resolve_media(hass, media_content_id, None)
|
||||
if media.path is None:
|
||||
raise HomeAssistantError(
|
||||
"Only local attachments are currently supported"
|
||||
)
|
||||
resolved_attachments.append(
|
||||
conversation.Attachment(
|
||||
media_content_id=attachment["media_content_id"],
|
||||
url=media.url,
|
||||
media_content_id=media_content_id,
|
||||
mime_type=media.mime_type,
|
||||
path=media.path,
|
||||
)
|
||||
)
|
||||
|
||||
return await entity.internal_async_generate_data(
|
||||
GenDataTask(
|
||||
name=task_name,
|
||||
instructions=instructions,
|
||||
structure=structure,
|
||||
attachments=resolved_attachments,
|
||||
with async_get_chat_session(hass) as session:
|
||||
if created_files:
|
||||
|
||||
def cleanup_files() -> None:
|
||||
"""Cleanup temporary files."""
|
||||
for file in created_files:
|
||||
file.unlink(missing_ok=True)
|
||||
|
||||
@callback
|
||||
def cleanup_files_callback() -> None:
|
||||
"""Cleanup temporary files."""
|
||||
hass.async_add_executor_job(cleanup_files)
|
||||
|
||||
session.async_on_cleanup(cleanup_files_callback)
|
||||
|
||||
return await entity.internal_async_generate_data(
|
||||
session,
|
||||
GenDataTask(
|
||||
name=task_name,
|
||||
instructions=instructions,
|
||||
structure=structure,
|
||||
attachments=resolved_attachments or None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
|
@ -147,9 +147,6 @@ class Attachment:
|
||||
media_content_id: str
|
||||
"""Media content ID of the attachment."""
|
||||
|
||||
url: str
|
||||
"""URL of the attachment."""
|
||||
|
||||
mime_type: str
|
||||
"""MIME type of the attachment."""
|
||||
|
||||
|
@ -117,7 +117,6 @@ async def test_generate_data_service(
|
||||
for msg_attachment, attachment in zip(
|
||||
msg_attachments, task.attachments or [], strict=False
|
||||
):
|
||||
assert attachment.url == "http://example.com/media.mp4"
|
||||
assert attachment.mime_type == "video/mp4"
|
||||
assert attachment.media_content_id == msg_attachment["media_content_id"]
|
||||
assert attachment.path == Path("media.mp4")
|
||||
|
@ -1,18 +1,26 @@
|
||||
"""Test tasks for the AI Task integration."""
|
||||
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data
|
||||
from homeassistant.components.camera import Image
|
||||
from homeassistant.components.conversation import async_get_chat_log
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import chat_session
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||
|
||||
from tests.common import async_fire_time_changed
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
@ -154,3 +162,83 @@ async def test_generate_data_attachments_not_supported(
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def test_generate_data_mixed_attachments(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
) -> None:
|
||||
"""Test generating data with both camera and regular media source attachments."""
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.camera.async_get_image",
|
||||
return_value=Image(content_type="image/jpeg", content=b"fake_camera_jpeg"),
|
||||
) as mock_get_image,
|
||||
patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
return_value=media_source.PlayMedia(
|
||||
url="http://example.com/test.mp4",
|
||||
mime_type="video/mp4",
|
||||
path=Path("/media/test.mp4"),
|
||||
),
|
||||
) as mock_resolve_media,
|
||||
):
|
||||
await async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Analyze these files",
|
||||
attachments=[
|
||||
{
|
||||
"media_content_id": "media-source://camera/camera.front_door",
|
||||
"media_content_type": "image/jpeg",
|
||||
},
|
||||
{
|
||||
"media_content_id": "media-source://media_player/video.mp4",
|
||||
"media_content_type": "video/mp4",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Verify both methods were called
|
||||
mock_get_image.assert_called_once_with(hass, "camera.front_door")
|
||||
mock_resolve_media.assert_called_once_with(
|
||||
hass, "media-source://media_player/video.mp4", None
|
||||
)
|
||||
|
||||
# Check attachments
|
||||
assert len(mock_ai_task_entity.mock_generate_data_tasks) == 1
|
||||
task = mock_ai_task_entity.mock_generate_data_tasks[0]
|
||||
assert task.attachments is not None
|
||||
assert len(task.attachments) == 2
|
||||
|
||||
# Check camera attachment
|
||||
camera_attachment = task.attachments[0]
|
||||
assert (
|
||||
camera_attachment.media_content_id == "media-source://camera/camera.front_door"
|
||||
)
|
||||
assert camera_attachment.mime_type == "image/jpeg"
|
||||
assert isinstance(camera_attachment.path, Path)
|
||||
assert camera_attachment.path.suffix == ".jpg"
|
||||
|
||||
# Verify camera snapshot content
|
||||
assert camera_attachment.path.exists()
|
||||
content = await hass.async_add_executor_job(camera_attachment.path.read_bytes)
|
||||
assert content == b"fake_camera_jpeg"
|
||||
|
||||
# Trigger clean up
|
||||
async_fire_time_changed(
|
||||
hass,
|
||||
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + timedelta(seconds=1),
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Verify the temporary file cleaned up
|
||||
assert not camera_attachment.path.exists()
|
||||
|
||||
# Check regular media attachment
|
||||
media_attachment = task.attachments[1]
|
||||
assert media_attachment.media_content_id == "media-source://media_player/video.mp4"
|
||||
assert media_attachment.mime_type == "video/mp4"
|
||||
assert media_attachment.path == Path("/media/test.mp4")
|
||||
|
Loading…
x
Reference in New Issue
Block a user