From 008e2a3d10070e7ad5ac93ff679a9836cbc6ab8d Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 6 Jul 2025 19:33:41 +0200 Subject: [PATCH] Add attachment support to AI task (#148120) --- homeassistant/components/ai_task/__init__.py | 7 ++- homeassistant/components/ai_task/const.py | 4 ++ .../components/ai_task/manifest.json | 2 +- .../components/ai_task/services.yaml | 6 ++ homeassistant/components/ai_task/strings.json | 4 ++ homeassistant/components/ai_task/task.py | 45 +++++++++++++- tests/components/ai_task/conftest.py | 4 +- tests/components/ai_task/test_init.py | 59 +++++++++++++++---- tests/components/ai_task/test_task.py | 37 ++++++++++-- 9 files changed, 147 insertions(+), 21 deletions(-) diff --git a/homeassistant/components/ai_task/__init__.py b/homeassistant/components/ai_task/__init__.py index 95c080cc472..a472b0db131 100644 --- a/homeassistant/components/ai_task/__init__.py +++ b/homeassistant/components/ai_task/__init__.py @@ -20,6 +20,7 @@ from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType from .const import ( + ATTR_ATTACHMENTS, ATTR_INSTRUCTIONS, ATTR_REQUIRED, ATTR_STRUCTURE, @@ -32,7 +33,7 @@ from .const import ( ) from .entity import AITaskEntity from .http import async_setup as async_setup_http -from .task import GenDataTask, GenDataTaskResult, async_generate_data +from .task import GenDataTask, GenDataTaskResult, PlayMediaWithId, async_generate_data __all__ = [ "DOMAIN", @@ -40,6 +41,7 @@ __all__ = [ "AITaskEntityFeature", "GenDataTask", "GenDataTaskResult", + "PlayMediaWithId", "async_generate_data", "async_setup", "async_setup_entry", @@ -92,6 +94,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: vol.Schema({str: STRUCTURE_FIELD_SCHEMA}), _validate_structure_fields, ), + vol.Optional(ATTR_ATTACHMENTS): vol.All( + cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})] + ), } ), supports_response=SupportsResponse.ONLY, diff --git a/homeassistant/components/ai_task/const.py b/homeassistant/components/ai_task/const.py index fa8702ed69e..09948e9b673 100644 --- a/homeassistant/components/ai_task/const.py +++ b/homeassistant/components/ai_task/const.py @@ -23,6 +23,7 @@ ATTR_INSTRUCTIONS: Final = "instructions" ATTR_TASK_NAME: Final = "task_name" ATTR_STRUCTURE: Final = "structure" ATTR_REQUIRED: Final = "required" +ATTR_ATTACHMENTS: Final = "attachments" DEFAULT_SYSTEM_PROMPT = ( "You are a Home Assistant expert and help users with their tasks." @@ -34,3 +35,6 @@ class AITaskEntityFeature(IntFlag): GENERATE_DATA = 1 """Generate data based on instructions.""" + + SUPPORT_ATTACHMENTS = 2 + """Support attachments with generate data.""" diff --git a/homeassistant/components/ai_task/manifest.json b/homeassistant/components/ai_task/manifest.json index c685410530d..c3e33e7d411 100644 --- a/homeassistant/components/ai_task/manifest.json +++ b/homeassistant/components/ai_task/manifest.json @@ -2,7 +2,7 @@ "domain": "ai_task", "name": "AI Task", "codeowners": ["@home-assistant/core"], - "dependencies": ["conversation"], + "dependencies": ["conversation", "media_source"], "documentation": "https://www.home-assistant.io/integrations/ai_task", "integration_type": "system", "quality_scale": "internal" diff --git a/homeassistant/components/ai_task/services.yaml b/homeassistant/components/ai_task/services.yaml index d55b0e60fac..4298ab62a07 100644 --- a/homeassistant/components/ai_task/services.yaml +++ b/homeassistant/components/ai_task/services.yaml @@ -23,3 +23,9 @@ generate_data: example: '{ "name": { "selector": { "text": }, "description": "Name of the user", "required": "True" } } }, "age": { "selector": { "number": }, "description": "Age of the user" } }' selector: object: + attachments: + required: false + selector: + media: + accept: + - "*" diff --git a/homeassistant/components/ai_task/strings.json b/homeassistant/components/ai_task/strings.json index 92106c3baca..261381b7c31 100644 --- a/homeassistant/components/ai_task/strings.json +++ b/homeassistant/components/ai_task/strings.json @@ -19,6 +19,10 @@ "structure": { "name": "Structured output", "description": "When set, the AI Task will output fields with this in structure. The structure is a dictionary where the keys are the field names and the values contain a 'description', a 'selector', and an optional 'required' field." + }, + "attachments": { + "name": "Attachments", + "description": "List of files to attach for multi-modal AI analysis." } } } diff --git a/homeassistant/components/ai_task/task.py b/homeassistant/components/ai_task/task.py index b6defbfad31..72d1018210c 100644 --- a/homeassistant/components/ai_task/task.py +++ b/homeassistant/components/ai_task/task.py @@ -2,17 +2,30 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, fields from typing import Any import voluptuous as vol +from homeassistant.components import media_source from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature +@dataclass(slots=True) +class PlayMediaWithId(media_source.PlayMedia): + """Play media with a media content ID.""" + + media_content_id: str + """Media source ID to play.""" + + def __str__(self) -> str: + """Return media source ID as a string.""" + return f"" + + async def async_generate_data( hass: HomeAssistant, *, @@ -20,6 +33,7 @@ async def async_generate_data( entity_id: str | None = None, instructions: str, structure: vol.Schema | None = None, + attachments: list[dict] | None = None, ) -> GenDataTaskResult: """Run a task in the AI Task integration.""" if entity_id is None: @@ -37,11 +51,37 @@ async def async_generate_data( f"AI Task entity {entity_id} does not support generating data" ) + # Resolve attachments + resolved_attachments: list[PlayMediaWithId] | None = None + + if attachments: + if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features: + raise HomeAssistantError( + f"AI Task entity {entity_id} does not support attachments" + ) + + resolved_attachments = [] + + for attachment in attachments: + media = await media_source.async_resolve_media( + hass, attachment["media_content_id"], None + ) + resolved_attachments.append( + PlayMediaWithId( + **{ + field.name: getattr(media, field.name) + for field in fields(media) + }, + media_content_id=attachment["media_content_id"], + ) + ) + return await entity.internal_async_generate_data( GenDataTask( name=task_name, instructions=instructions, structure=structure, + attachments=resolved_attachments, ) ) @@ -59,6 +99,9 @@ class GenDataTask: structure: vol.Schema | None = None """Optional structure for the data to be generated.""" + attachments: list[PlayMediaWithId] | None = None + """List of attachments to go along the instructions.""" + def __str__(self) -> str: """Return task as a string.""" return f"" diff --git a/tests/components/ai_task/conftest.py b/tests/components/ai_task/conftest.py index e80e70ddaed..05d34b15ddc 100644 --- a/tests/components/ai_task/conftest.py +++ b/tests/components/ai_task/conftest.py @@ -35,7 +35,9 @@ class MockAITaskEntity(AITaskEntity): """Mock AI Task entity for testing.""" _attr_name = "Test Task Entity" - _attr_supported_features = AITaskEntityFeature.GENERATE_DATA + _attr_supported_features = ( + AITaskEntityFeature.GENERATE_DATA | AITaskEntityFeature.SUPPORT_ATTACHMENTS + ) def __init__(self) -> None: """Initialize the mock entity.""" diff --git a/tests/components/ai_task/test_init.py b/tests/components/ai_task/test_init.py index d32b09adec5..840285493ac 100644 --- a/tests/components/ai_task/test_init.py +++ b/tests/components/ai_task/test_init.py @@ -1,11 +1,13 @@ """Test initialization of the AI Task component.""" from typing import Any +from unittest.mock import patch from freezegun.api import FrozenDateTimeFactory import pytest import voluptuous as vol +from homeassistant.components import media_source from homeassistant.components.ai_task import AITaskPreferences from homeassistant.components.ai_task.const import DATA_PREFERENCES from homeassistant.core import HomeAssistant @@ -58,7 +60,15 @@ async def test_preferences_storage_load( ), ( {}, - {"entity_id": TEST_ENTITY_ID}, + { + "entity_id": TEST_ENTITY_ID, + "attachments": [ + { + "media_content_id": "media-source://mock/blah_blah_blah.mp4", + "media_content_type": "video/mp4", + } + ], + }, ), ], ) @@ -68,25 +78,50 @@ async def test_generate_data_service( freezer: FrozenDateTimeFactory, set_preferences: dict[str, str | None], msg_extra: dict[str, str], + mock_ai_task_entity: MockAITaskEntity, ) -> None: """Test the generate data service.""" preferences = hass.data[DATA_PREFERENCES] preferences.async_set_preferences(**set_preferences) - result = await hass.services.async_call( - "ai_task", - "generate_data", - { - "task_name": "Test Name", - "instructions": "Test prompt", - } - | msg_extra, - blocking=True, - return_response=True, - ) + with patch( + "homeassistant.components.media_source.async_resolve_media", + return_value=media_source.PlayMedia( + url="http://example.com/media.mp4", + mime_type="video/mp4", + ), + ): + result = await hass.services.async_call( + "ai_task", + "generate_data", + { + "task_name": "Test Name", + "instructions": "Test prompt", + } + | msg_extra, + blocking=True, + return_response=True, + ) assert result["data"] == "Mock result" + assert len(mock_ai_task_entity.mock_generate_data_tasks) == 1 + task = mock_ai_task_entity.mock_generate_data_tasks[0] + + assert len(task.attachments or []) == len( + msg_attachments := msg_extra.get("attachments", []) + ) + + for msg_attachment, attachment in zip( + msg_attachments, task.attachments or [], strict=False + ): + assert attachment.url == "http://example.com/media.mp4" + assert attachment.mime_type == "video/mp4" + assert attachment.media_content_id == msg_attachment["media_content_id"] + assert ( + str(attachment) == f"" + ) + async def test_generate_data_service_structure_fields( hass: HomeAssistant, diff --git a/tests/components/ai_task/test_task.py b/tests/components/ai_task/test_task.py index bed760c8a1d..b11d96823cc 100644 --- a/tests/components/ai_task/test_task.py +++ b/tests/components/ai_task/test_task.py @@ -16,13 +16,13 @@ from .conftest import TEST_ENTITY_ID, MockAITaskEntity from tests.typing import WebSocketGenerator -async def test_run_task_preferred_entity( +async def test_generate_data_preferred_entity( hass: HomeAssistant, init_components: None, mock_ai_task_entity: MockAITaskEntity, hass_ws_client: WebSocketGenerator, ) -> None: - """Test running a task with an unknown entity.""" + """Test generating data with entity via preferences.""" client = await hass_ws_client(hass) with pytest.raises( @@ -90,11 +90,11 @@ async def test_run_task_preferred_entity( ) -async def test_run_data_task_unknown_entity( +async def test_generate_data_unknown_entity( hass: HomeAssistant, init_components: None, ) -> None: - """Test running a data task with an unknown entity.""" + """Test generating data with an unknown entity.""" with pytest.raises( HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found" @@ -113,7 +113,7 @@ async def test_run_data_task_updates_chat_log( init_components: None, snapshot: SnapshotAssertion, ) -> None: - """Test that running a data task updates the chat log.""" + """Test that generating data updates the chat log.""" result = await async_generate_data( hass, task_name="Test Task", @@ -127,3 +127,30 @@ async def test_run_data_task_updates_chat_log( async_get_chat_log(hass, session) as chat_log, ): assert chat_log.content == snapshot + + +async def test_generate_data_attachments_not_supported( + hass: HomeAssistant, + init_components: None, + mock_ai_task_entity: MockAITaskEntity, +) -> None: + """Test generating data with attachments when entity doesn't support them.""" + # Remove attachment support from the entity + mock_ai_task_entity._attr_supported_features = AITaskEntityFeature.GENERATE_DATA + + with pytest.raises( + HomeAssistantError, + match="AI Task entity ai_task.test_task_entity does not support attachments", + ): + await async_generate_data( + hass, + task_name="Test Task", + entity_id=TEST_ENTITY_ID, + instructions="Test prompt", + attachments=[ + { + "media_content_id": "media-source://mock/test.mp4", + "media_content_type": "video/mp4", + } + ], + )