mirror of
https://github.com/home-assistant/core.git
synced 2025-07-29 16:17:20 +00:00
Allow AI Task to handle camera attachments (#148753)
This commit is contained in:
parent
816977dd75
commit
e2cc51f21d
@ -13,7 +13,7 @@ from homeassistant.components.conversation import (
|
|||||||
)
|
)
|
||||||
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||||
from homeassistant.helpers import llm
|
from homeassistant.helpers import llm
|
||||||
from homeassistant.helpers.chat_session import async_get_chat_session
|
from homeassistant.helpers.chat_session import ChatSession
|
||||||
from homeassistant.helpers.restore_state import RestoreEntity
|
from homeassistant.helpers.restore_state import RestoreEntity
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
@ -56,12 +56,12 @@ class AITaskEntity(RestoreEntity):
|
|||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def _async_get_ai_task_chat_log(
|
async def _async_get_ai_task_chat_log(
|
||||||
self,
|
self,
|
||||||
|
session: ChatSession,
|
||||||
task: GenDataTask,
|
task: GenDataTask,
|
||||||
) -> AsyncGenerator[ChatLog]:
|
) -> AsyncGenerator[ChatLog]:
|
||||||
"""Context manager used to manage the ChatLog used during an AI Task."""
|
"""Context manager used to manage the ChatLog used during an AI Task."""
|
||||||
# pylint: disable-next=contextmanager-generator-missing-cleanup
|
# pylint: disable-next=contextmanager-generator-missing-cleanup
|
||||||
with (
|
with (
|
||||||
async_get_chat_session(self.hass) as session,
|
|
||||||
async_get_chat_log(
|
async_get_chat_log(
|
||||||
self.hass,
|
self.hass,
|
||||||
session,
|
session,
|
||||||
@ -88,12 +88,13 @@ class AITaskEntity(RestoreEntity):
|
|||||||
@final
|
@final
|
||||||
async def internal_async_generate_data(
|
async def internal_async_generate_data(
|
||||||
self,
|
self,
|
||||||
|
session: ChatSession,
|
||||||
task: GenDataTask,
|
task: GenDataTask,
|
||||||
) -> GenDataTaskResult:
|
) -> GenDataTaskResult:
|
||||||
"""Run a gen data task."""
|
"""Run a gen data task."""
|
||||||
self.__last_activity = dt_util.utcnow().isoformat()
|
self.__last_activity = dt_util.utcnow().isoformat()
|
||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
async with self._async_get_ai_task_chat_log(task) as chat_log:
|
async with self._async_get_ai_task_chat_log(session, task) as chat_log:
|
||||||
return await self._async_generate_data(task, chat_log)
|
return await self._async_generate_data(task, chat_log)
|
||||||
|
|
||||||
async def _async_generate_data(
|
async def _async_generate_data(
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
{
|
{
|
||||||
"domain": "ai_task",
|
"domain": "ai_task",
|
||||||
"name": "AI Task",
|
"name": "AI Task",
|
||||||
|
"after_dependencies": ["camera"],
|
||||||
"codeowners": ["@home-assistant/core"],
|
"codeowners": ["@home-assistant/core"],
|
||||||
"dependencies": ["conversation", "media_source"],
|
"dependencies": ["conversation", "media_source"],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/ai_task",
|
"documentation": "https://www.home-assistant.io/integrations/ai_task",
|
||||||
|
@ -3,17 +3,32 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
import mimetypes
|
||||||
|
from pathlib import Path
|
||||||
|
import tempfile
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import conversation, media_source
|
from homeassistant.components import camera, conversation, media_source
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers.chat_session import async_get_chat_session
|
||||||
|
|
||||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||||
|
|
||||||
|
|
||||||
|
def _save_camera_snapshot(image: camera.Image) -> Path:
|
||||||
|
"""Save camera snapshot to temp file."""
|
||||||
|
with tempfile.NamedTemporaryFile(
|
||||||
|
mode="wb",
|
||||||
|
suffix=mimetypes.guess_extension(image.content_type, False),
|
||||||
|
delete=False,
|
||||||
|
) as temp_file:
|
||||||
|
temp_file.write(image.content)
|
||||||
|
return Path(temp_file.name)
|
||||||
|
|
||||||
|
|
||||||
async def async_generate_data(
|
async def async_generate_data(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
*,
|
*,
|
||||||
@ -40,40 +55,78 @@ async def async_generate_data(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Resolve attachments
|
# Resolve attachments
|
||||||
resolved_attachments: list[conversation.Attachment] | None = None
|
resolved_attachments: list[conversation.Attachment] = []
|
||||||
|
created_files: list[Path] = []
|
||||||
|
|
||||||
if attachments:
|
if (
|
||||||
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
|
attachments
|
||||||
|
and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features
|
||||||
|
):
|
||||||
raise HomeAssistantError(
|
raise HomeAssistantError(
|
||||||
f"AI Task entity {entity_id} does not support attachments"
|
f"AI Task entity {entity_id} does not support attachments"
|
||||||
)
|
)
|
||||||
|
|
||||||
resolved_attachments = []
|
for attachment in attachments or []:
|
||||||
|
media_content_id = attachment["media_content_id"]
|
||||||
|
|
||||||
for attachment in attachments:
|
# Special case for camera media sources
|
||||||
media = await media_source.async_resolve_media(
|
if media_content_id.startswith("media-source://camera/"):
|
||||||
hass, attachment["media_content_id"], None
|
# Extract entity_id from the media content ID
|
||||||
|
entity_id = media_content_id.removeprefix("media-source://camera/")
|
||||||
|
|
||||||
|
# Get snapshot from camera
|
||||||
|
image = await camera.async_get_image(hass, entity_id)
|
||||||
|
|
||||||
|
temp_filename = await hass.async_add_executor_job(
|
||||||
|
_save_camera_snapshot, image
|
||||||
)
|
)
|
||||||
|
created_files.append(temp_filename)
|
||||||
|
|
||||||
|
resolved_attachments.append(
|
||||||
|
conversation.Attachment(
|
||||||
|
media_content_id=media_content_id,
|
||||||
|
mime_type=image.content_type,
|
||||||
|
path=temp_filename,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Handle regular media sources
|
||||||
|
media = await media_source.async_resolve_media(hass, media_content_id, None)
|
||||||
if media.path is None:
|
if media.path is None:
|
||||||
raise HomeAssistantError(
|
raise HomeAssistantError(
|
||||||
"Only local attachments are currently supported"
|
"Only local attachments are currently supported"
|
||||||
)
|
)
|
||||||
resolved_attachments.append(
|
resolved_attachments.append(
|
||||||
conversation.Attachment(
|
conversation.Attachment(
|
||||||
media_content_id=attachment["media_content_id"],
|
media_content_id=media_content_id,
|
||||||
url=media.url,
|
|
||||||
mime_type=media.mime_type,
|
mime_type=media.mime_type,
|
||||||
path=media.path,
|
path=media.path,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with async_get_chat_session(hass) as session:
|
||||||
|
if created_files:
|
||||||
|
|
||||||
|
def cleanup_files() -> None:
|
||||||
|
"""Cleanup temporary files."""
|
||||||
|
for file in created_files:
|
||||||
|
file.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def cleanup_files_callback() -> None:
|
||||||
|
"""Cleanup temporary files."""
|
||||||
|
hass.async_add_executor_job(cleanup_files)
|
||||||
|
|
||||||
|
session.async_on_cleanup(cleanup_files_callback)
|
||||||
|
|
||||||
return await entity.internal_async_generate_data(
|
return await entity.internal_async_generate_data(
|
||||||
|
session,
|
||||||
GenDataTask(
|
GenDataTask(
|
||||||
name=task_name,
|
name=task_name,
|
||||||
instructions=instructions,
|
instructions=instructions,
|
||||||
structure=structure,
|
structure=structure,
|
||||||
attachments=resolved_attachments,
|
attachments=resolved_attachments or None,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -147,9 +147,6 @@ class Attachment:
|
|||||||
media_content_id: str
|
media_content_id: str
|
||||||
"""Media content ID of the attachment."""
|
"""Media content ID of the attachment."""
|
||||||
|
|
||||||
url: str
|
|
||||||
"""URL of the attachment."""
|
|
||||||
|
|
||||||
mime_type: str
|
mime_type: str
|
||||||
"""MIME type of the attachment."""
|
"""MIME type of the attachment."""
|
||||||
|
|
||||||
|
@ -117,7 +117,6 @@ async def test_generate_data_service(
|
|||||||
for msg_attachment, attachment in zip(
|
for msg_attachment, attachment in zip(
|
||||||
msg_attachments, task.attachments or [], strict=False
|
msg_attachments, task.attachments or [], strict=False
|
||||||
):
|
):
|
||||||
assert attachment.url == "http://example.com/media.mp4"
|
|
||||||
assert attachment.mime_type == "video/mp4"
|
assert attachment.mime_type == "video/mp4"
|
||||||
assert attachment.media_content_id == msg_attachment["media_content_id"]
|
assert attachment.media_content_id == msg_attachment["media_content_id"]
|
||||||
assert attachment.path == Path("media.mp4")
|
assert attachment.path == Path("media.mp4")
|
||||||
|
@ -1,18 +1,26 @@
|
|||||||
"""Test tasks for the AI Task integration."""
|
"""Test tasks for the AI Task integration."""
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
|
from homeassistant.components import media_source
|
||||||
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data
|
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data
|
||||||
|
from homeassistant.components.camera import Image
|
||||||
from homeassistant.components.conversation import async_get_chat_log
|
from homeassistant.components.conversation import async_get_chat_log
|
||||||
from homeassistant.const import STATE_UNKNOWN
|
from homeassistant.const import STATE_UNKNOWN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import chat_session
|
from homeassistant.helpers import chat_session
|
||||||
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||||
|
|
||||||
|
from tests.common import async_fire_time_changed
|
||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
@ -154,3 +162,83 @@ async def test_generate_data_attachments_not_supported(
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_data_mixed_attachments(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: None,
|
||||||
|
mock_ai_task_entity: MockAITaskEntity,
|
||||||
|
) -> None:
|
||||||
|
"""Test generating data with both camera and regular media source attachments."""
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.camera.async_get_image",
|
||||||
|
return_value=Image(content_type="image/jpeg", content=b"fake_camera_jpeg"),
|
||||||
|
) as mock_get_image,
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
|
return_value=media_source.PlayMedia(
|
||||||
|
url="http://example.com/test.mp4",
|
||||||
|
mime_type="video/mp4",
|
||||||
|
path=Path("/media/test.mp4"),
|
||||||
|
),
|
||||||
|
) as mock_resolve_media,
|
||||||
|
):
|
||||||
|
await async_generate_data(
|
||||||
|
hass,
|
||||||
|
task_name="Test Task",
|
||||||
|
entity_id=TEST_ENTITY_ID,
|
||||||
|
instructions="Analyze these files",
|
||||||
|
attachments=[
|
||||||
|
{
|
||||||
|
"media_content_id": "media-source://camera/camera.front_door",
|
||||||
|
"media_content_type": "image/jpeg",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"media_content_id": "media-source://media_player/video.mp4",
|
||||||
|
"media_content_type": "video/mp4",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify both methods were called
|
||||||
|
mock_get_image.assert_called_once_with(hass, "camera.front_door")
|
||||||
|
mock_resolve_media.assert_called_once_with(
|
||||||
|
hass, "media-source://media_player/video.mp4", None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check attachments
|
||||||
|
assert len(mock_ai_task_entity.mock_generate_data_tasks) == 1
|
||||||
|
task = mock_ai_task_entity.mock_generate_data_tasks[0]
|
||||||
|
assert task.attachments is not None
|
||||||
|
assert len(task.attachments) == 2
|
||||||
|
|
||||||
|
# Check camera attachment
|
||||||
|
camera_attachment = task.attachments[0]
|
||||||
|
assert (
|
||||||
|
camera_attachment.media_content_id == "media-source://camera/camera.front_door"
|
||||||
|
)
|
||||||
|
assert camera_attachment.mime_type == "image/jpeg"
|
||||||
|
assert isinstance(camera_attachment.path, Path)
|
||||||
|
assert camera_attachment.path.suffix == ".jpg"
|
||||||
|
|
||||||
|
# Verify camera snapshot content
|
||||||
|
assert camera_attachment.path.exists()
|
||||||
|
content = await hass.async_add_executor_job(camera_attachment.path.read_bytes)
|
||||||
|
assert content == b"fake_camera_jpeg"
|
||||||
|
|
||||||
|
# Trigger clean up
|
||||||
|
async_fire_time_changed(
|
||||||
|
hass,
|
||||||
|
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + timedelta(seconds=1),
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Verify the temporary file cleaned up
|
||||||
|
assert not camera_attachment.path.exists()
|
||||||
|
|
||||||
|
# Check regular media attachment
|
||||||
|
media_attachment = task.attachments[1]
|
||||||
|
assert media_attachment.media_content_id == "media-source://media_player/video.mp4"
|
||||||
|
assert media_attachment.mime_type == "video/mp4"
|
||||||
|
assert media_attachment.path == Path("/media/test.mp4")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user