From 2be6acec03db64200c0254f068ffb0d4f51febb5 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 18 May 2025 02:45:03 +0000 Subject: [PATCH] Add AI Task integration --- CODEOWNERS | 2 + homeassistant/components/ai_task/__init__.py | 48 +++++++ homeassistant/components/ai_task/const.py | 34 +++++ homeassistant/components/ai_task/entity.py | 95 +++++++++++++ homeassistant/components/ai_task/http.py | 39 ++++++ .../components/ai_task/manifest.json | 9 ++ homeassistant/components/ai_task/task.py | 67 ++++++++++ homeassistant/const.py | 1 + tests/components/ai_task/__init__.py | 1 + tests/components/ai_task/conftest.py | 125 ++++++++++++++++++ .../ai_task/snapshots/test_task.ambr | 22 +++ tests/components/ai_task/test_entity.py | 41 ++++++ tests/components/ai_task/test_http.py | 39 ++++++ tests/components/ai_task/test_task.py | 53 ++++++++ 14 files changed, 576 insertions(+) create mode 100644 homeassistant/components/ai_task/__init__.py create mode 100644 homeassistant/components/ai_task/const.py create mode 100644 homeassistant/components/ai_task/entity.py create mode 100644 homeassistant/components/ai_task/http.py create mode 100644 homeassistant/components/ai_task/manifest.json create mode 100644 homeassistant/components/ai_task/task.py create mode 100644 tests/components/ai_task/__init__.py create mode 100644 tests/components/ai_task/conftest.py create mode 100644 tests/components/ai_task/snapshots/test_task.ambr create mode 100644 tests/components/ai_task/test_entity.py create mode 100644 tests/components/ai_task/test_http.py create mode 100644 tests/components/ai_task/test_task.py diff --git a/CODEOWNERS b/CODEOWNERS index 6670b411df4..1ceb6ff0e7d 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -57,6 +57,8 @@ build.json @home-assistant/supervisor /tests/components/aemet/ @Noltari /homeassistant/components/agent_dvr/ @ispysoftware /tests/components/agent_dvr/ @ispysoftware +/homeassistant/components/ai_task/ @home-assistant/core +/tests/components/ai_task/ @home-assistant/core /homeassistant/components/air_quality/ @home-assistant/core /tests/components/air_quality/ @home-assistant/core /homeassistant/components/airgradient/ @airgradienthq @joostlek diff --git a/homeassistant/components/ai_task/__init__.py b/homeassistant/components/ai_task/__init__.py new file mode 100644 index 00000000000..a650dfdd0be --- /dev/null +++ b/homeassistant/components/ai_task/__init__.py @@ -0,0 +1,48 @@ +"""Integration to offer AI tasks to Home Assistant.""" + +import logging + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.helpers.typing import ConfigType + +from .const import DATA_COMPONENT, DOMAIN +from .entity import AITaskEntity +from .http import async_setup as async_setup_conversation_http +from .task import GenTextTask, GenTextTaskResult, GenTextTaskType, async_generate_text + +__all__ = [ + "DOMAIN", + "AITaskEntity", + "GenTextTask", + "GenTextTaskResult", + "GenTextTaskType", + "async_generate_text", + "async_setup", + "async_setup_entry", + "async_unload_entry", +] + +_LOGGER = logging.getLogger(__name__) + +CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) + + +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + """Register the process service.""" + entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass) + hass.data[DATA_COMPONENT] = entity_component + async_setup_conversation_http(hass) + return True + + +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Set up a config entry.""" + return await hass.data[DATA_COMPONENT].async_setup_entry(entry) + + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Unload a config entry.""" + return await hass.data[DATA_COMPONENT].async_unload_entry(entry) diff --git a/homeassistant/components/ai_task/const.py b/homeassistant/components/ai_task/const.py new file mode 100644 index 00000000000..9d580ab39f5 --- /dev/null +++ b/homeassistant/components/ai_task/const.py @@ -0,0 +1,34 @@ +"""Constants for the AI Task integration.""" + +from __future__ import annotations + +from enum import StrEnum +from typing import TYPE_CHECKING + +from homeassistant.util.hass_dict import HassKey + +if TYPE_CHECKING: + from homeassistant.helpers.entity_component import EntityComponent + + from .entity import AITaskEntity + +DOMAIN = "ai_task" +DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN) + +DEFAULT_SYSTEM_PROMPT = ( + "You are a Home Assistant expert and help users with their tasks." +) + + +class GenTextTaskType(StrEnum): + """Generate text task types. + + A task type describes the intent of the request in order to + match the right model for balance of cost and quality. + """ + + GENERATE = "generate" + """Generate content, which may target a higher quality result.""" + + SUMMARY = "summary" + """Summarize existing content, which be able to use a more cost effective model.""" diff --git a/homeassistant/components/ai_task/entity.py b/homeassistant/components/ai_task/entity.py new file mode 100644 index 00000000000..42c659a8289 --- /dev/null +++ b/homeassistant/components/ai_task/entity.py @@ -0,0 +1,95 @@ +"""Entity for the AI Task integration.""" + +from collections.abc import AsyncGenerator +import contextlib +from typing import final + +from homeassistant.components.conversation import ( + ChatLog, + UserContent, + async_get_chat_log, +) +from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN +from homeassistant.helpers import llm +from homeassistant.helpers.chat_session import async_get_chat_session +from homeassistant.helpers.restore_state import RestoreEntity +from homeassistant.util import dt as dt_util + +from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN +from .task import GenTextTask, GenTextTaskResult + + +class AITaskEntity(RestoreEntity): + """Entity that supports conversations.""" + + _attr_should_poll = False + __last_activity: str | None = None + + @property + @final + def state(self) -> str | None: + """Return the state of the entity.""" + if self.__last_activity is None: + return None + return self.__last_activity + + async def async_internal_added_to_hass(self) -> None: + """Call when the entity is added to hass.""" + await super().async_internal_added_to_hass() + state = await self.async_get_last_state() + if ( + state is not None + and state.state is not None + and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) + ): + self.__last_activity = state.state + + @final + @contextlib.asynccontextmanager + async def _async_get_ai_task_chat_log( + self, + task: GenTextTask, + ) -> AsyncGenerator[ChatLog]: + """Context manager used to manage the ChatLog used during an AI Task.""" + # pylint: disable-next=contextmanager-generator-missing-cleanup + with ( + async_get_chat_session(self.hass) as session, + async_get_chat_log( + self.hass, + session, + None, + ) as chat_log, + ): + await chat_log.async_provide_llm_data( + llm.LLMContext( + platform=self.platform.domain, + context=None, + language=None, + assistant=DOMAIN, + device_id=None, + ), + user_llm_prompt=DEFAULT_SYSTEM_PROMPT, + ) + + chat_log.async_add_user_content(UserContent(task.instructions)) + + yield chat_log + + @final + async def internal_async_generate_text( + self, + task: GenTextTask, + ) -> GenTextTaskResult: + """Run a gen text task.""" + self.__last_activity = dt_util.utcnow().isoformat() + self.async_write_ha_state() + async with self._async_get_ai_task_chat_log(task) as chat_log: + return await self._async_generate_text(task, chat_log) + + async def _async_generate_text( + self, + task: GenTextTask, + chat_log: ChatLog, + ) -> GenTextTaskResult: + """Handle a gen text task.""" + raise NotImplementedError diff --git a/homeassistant/components/ai_task/http.py b/homeassistant/components/ai_task/http.py new file mode 100644 index 00000000000..c79694194aa --- /dev/null +++ b/homeassistant/components/ai_task/http.py @@ -0,0 +1,39 @@ +"""HTTP endpoint for AI Task integration.""" + +from typing import Any + +import voluptuous as vol + +from homeassistant.components import websocket_api +from homeassistant.core import HomeAssistant, callback + +from .task import GenTextTaskType, async_generate_text + + +@callback +def async_setup(hass: HomeAssistant) -> None: + """Set up the HTTP API for the conversation integration.""" + websocket_api.async_register_command(hass, websocket_generate_text) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "ai_task/generate_text", + vol.Required("task_name"): str, + vol.Required("entity_id"): str, + vol.Required("task_type"): (lambda v: GenTextTaskType(v)), # pylint: disable=unnecessary-lambda + vol.Required("instructions"): str, + } +) +@websocket_api.require_admin +@websocket_api.async_response +async def websocket_generate_text( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Run a generate text task.""" + msg.pop("type") + msg_id = msg.pop("id") + result = await async_generate_text(hass=hass, **msg) + connection.send_result(msg_id, result.as_dict()) diff --git a/homeassistant/components/ai_task/manifest.json b/homeassistant/components/ai_task/manifest.json new file mode 100644 index 00000000000..c685410530d --- /dev/null +++ b/homeassistant/components/ai_task/manifest.json @@ -0,0 +1,9 @@ +{ + "domain": "ai_task", + "name": "AI Task", + "codeowners": ["@home-assistant/core"], + "dependencies": ["conversation"], + "documentation": "https://www.home-assistant.io/integrations/ai_task", + "integration_type": "system", + "quality_scale": "internal" +} diff --git a/homeassistant/components/ai_task/task.py b/homeassistant/components/ai_task/task.py new file mode 100644 index 00000000000..2f6c901dce8 --- /dev/null +++ b/homeassistant/components/ai_task/task.py @@ -0,0 +1,67 @@ +"""AI tasks to be handled by agents.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from homeassistant.core import HomeAssistant + +from .const import DATA_COMPONENT, GenTextTaskType + + +async def async_generate_text( + hass: HomeAssistant, + *, + task_name: str, + entity_id: str, + task_type: GenTextTaskType, + instructions: str, +) -> GenTextTaskResult: + """Run a task in the AI Task integration.""" + entity = hass.data[DATA_COMPONENT].get_entity(entity_id) + if entity is None: + raise ValueError(f"AI Task entity {entity_id} not found") + + return await entity.internal_async_generate_text( + GenTextTask( + name=task_name, + type=task_type, + instructions=instructions, + ) + ) + + +@dataclass(slots=True) +class GenTextTask: + """Gen text task to be processed.""" + + name: str + """Name of the task.""" + + type: GenTextTaskType + """Type of the task.""" + + instructions: str + """Instructions on what needs to be done.""" + + def __str__(self) -> str: + """Return task as a string.""" + return f"" + + +@dataclass(slots=True) +class GenTextTaskResult: + """Result of gen text task.""" + + conversation_id: str + """Unique identifier for the conversation.""" + + result: str + """Result of the task.""" + + def as_dict(self) -> dict[str, str]: + """Return result as a dict.""" + return { + "conversation_id": self.conversation_id, + "result": self.result, + } diff --git a/homeassistant/const.py b/homeassistant/const.py index f692f428920..0abdcd59b77 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -40,6 +40,7 @@ PLATFORM_FORMAT: Final = "{platform}.{domain}" class Platform(StrEnum): """Available entity platforms.""" + AI_TASK = "ai_task" AIR_QUALITY = "air_quality" ALARM_CONTROL_PANEL = "alarm_control_panel" ASSIST_SATELLITE = "assist_satellite" diff --git a/tests/components/ai_task/__init__.py b/tests/components/ai_task/__init__.py new file mode 100644 index 00000000000..b4ca4688eb4 --- /dev/null +++ b/tests/components/ai_task/__init__.py @@ -0,0 +1 @@ +"""Tests for the AI Task integration.""" diff --git a/tests/components/ai_task/conftest.py b/tests/components/ai_task/conftest.py new file mode 100644 index 00000000000..a8cda8b570b --- /dev/null +++ b/tests/components/ai_task/conftest.py @@ -0,0 +1,125 @@ +"""Test helpers for AI Task integration.""" + +import pytest + +from homeassistant.components.ai_task import ( + DOMAIN, + AITaskEntity, + GenTextTask, + GenTextTaskResult, +) +from homeassistant.components.conversation import AssistantContent, ChatLog +from homeassistant.config_entries import ConfigEntry, ConfigFlow +from homeassistant.const import Platform +from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback +from homeassistant.setup import async_setup_component + +from tests.common import ( + MockConfigEntry, + MockModule, + MockPlatform, + mock_config_flow, + mock_integration, + mock_platform, +) + +TEST_DOMAIN = "test" +TEST_ENTITY_ID = "ai_task.test_task_entity" + + +class MockAITaskEntity(AITaskEntity): + """Mock AI Task entity for testing.""" + + _attr_name = "Test Task Entity" + + def __init__(self) -> None: + """Initialize the mock entity.""" + super().__init__() + self.mock_generate_text_tasks = [] + + async def _async_generate_text( + self, task: GenTextTask, chat_log: ChatLog + ) -> GenTextTaskResult: + """Mock handling of generate text task.""" + self.mock_generate_text_tasks.append(task) + chat_log.async_add_assistant_content_without_tools( + AssistantContent(self.entity_id, "Mock result") + ) + return GenTextTaskResult( + conversation_id=chat_log.conversation_id, + result="Mock result", + ) + + +@pytest.fixture +def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: + """Mock a configuration entry for AI Task.""" + entry = MockConfigEntry(domain=TEST_DOMAIN, entry_id="mock-test-entry") + entry.add_to_hass(hass) + return entry + + +@pytest.fixture +def mock_ai_task_entity( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> MockAITaskEntity: + """Mock AI Task entity.""" + return MockAITaskEntity() + + +@pytest.fixture +async def init_components( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_ai_task_entity: MockAITaskEntity, +): + """Initialize the AI Task integration with a mock entity.""" + assert await async_setup_component(hass, "homeassistant", {}) + + async def async_setup_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Set up test config entry.""" + await hass.config_entries.async_forward_entry_setups( + config_entry, [Platform.AI_TASK] + ) + return True + + async def async_unload_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Unload test config entry.""" + await hass.config_entries.async_forward_entry_unload( + config_entry, Platform.AI_TASK + ) + return True + + mock_integration( + hass, + MockModule( + TEST_DOMAIN, + async_setup_entry=async_setup_entry_init, + async_unload_entry=async_unload_entry_init, + ), + ) + + async def async_setup_entry_platform( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddConfigEntryEntitiesCallback, + ) -> None: + """Set up test tts platform via config entry.""" + async_add_entities([mock_ai_task_entity]) + + mock_platform( + hass, + f"{TEST_DOMAIN}.{DOMAIN}", + MockPlatform(async_setup_entry=async_setup_entry_platform), + ) + + mock_platform(hass, f"{TEST_DOMAIN}.config_flow") + + with mock_config_flow(TEST_DOMAIN, ConfigFlow): + assert await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() diff --git a/tests/components/ai_task/snapshots/test_task.ambr b/tests/components/ai_task/snapshots/test_task.ambr new file mode 100644 index 00000000000..6d155c82a68 --- /dev/null +++ b/tests/components/ai_task/snapshots/test_task.ambr @@ -0,0 +1,22 @@ +# serializer version: 1 +# name: test_run_text_task_updates_chat_log + list([ + dict({ + 'content': ''' + You are a Home Assistant expert and help users with their tasks. + Current time is 15:59:00. Today's date is 2025-06-14. + ''', + 'role': 'system', + }), + dict({ + 'content': 'Test prompt', + 'role': 'user', + }), + dict({ + 'agent_id': 'ai_task.test_task_entity', + 'content': 'Mock result', + 'role': 'assistant', + 'tool_calls': None, + }), + ]) +# --- diff --git a/tests/components/ai_task/test_entity.py b/tests/components/ai_task/test_entity.py new file mode 100644 index 00000000000..1689306bd21 --- /dev/null +++ b/tests/components/ai_task/test_entity.py @@ -0,0 +1,41 @@ +"""Tests for the AI Task entity model.""" + +from freezegun import freeze_time + +from homeassistant.components.ai_task import GenTextTaskType, async_generate_text +from homeassistant.const import STATE_UNKNOWN +from homeassistant.core import HomeAssistant + +from .conftest import TEST_ENTITY_ID, MockAITaskEntity + +from tests.common import MockConfigEntry + + +@freeze_time("2025-06-08 16:28:13") +async def test_state_generate_text( + hass: HomeAssistant, + init_components: None, + mock_config_entry: MockConfigEntry, + mock_ai_task_entity: MockAITaskEntity, +) -> None: + """Test the state of the AI Task entity is updated when generating text.""" + entity = hass.states.get(TEST_ENTITY_ID) + assert entity is not None + assert entity.state == STATE_UNKNOWN + + result = await async_generate_text( + hass, + task_name="Test task", + entity_id=TEST_ENTITY_ID, + task_type=GenTextTaskType.SUMMARY, + instructions="Test prompt", + ) + assert result.result == "Mock result" + + entity = hass.states.get(TEST_ENTITY_ID) + assert entity.state == "2025-06-08T16:28:13+00:00" + + assert mock_ai_task_entity.mock_generate_text_tasks + task = mock_ai_task_entity.mock_generate_text_tasks[0] + assert task.type == GenTextTaskType.SUMMARY + assert task.instructions == "Test prompt" diff --git a/tests/components/ai_task/test_http.py b/tests/components/ai_task/test_http.py new file mode 100644 index 00000000000..5421025bf3a --- /dev/null +++ b/tests/components/ai_task/test_http.py @@ -0,0 +1,39 @@ +"""Test the HTTP API for AI Task integration.""" + +from homeassistant.const import STATE_UNKNOWN +from homeassistant.core import HomeAssistant + +from .conftest import TEST_ENTITY_ID + +from tests.typing import WebSocketGenerator + + +async def test_ws_generate_text( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components: None, +) -> None: + """Test running a generate text task via the WebSocket API.""" + entity = hass.states.get(TEST_ENTITY_ID) + assert entity is not None + assert entity.state == STATE_UNKNOWN + + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + { + "type": "ai_task/generate_text", + "task_name": "Test Task", + "entity_id": TEST_ENTITY_ID, + "task_type": "summary", + "instructions": "Test prompt", + } + ) + + msg = await client.receive_json() + + assert msg["success"] + assert msg["result"]["result"] == "Mock result" + + entity = hass.states.get(TEST_ENTITY_ID) + assert entity.state != STATE_UNKNOWN diff --git a/tests/components/ai_task/test_task.py b/tests/components/ai_task/test_task.py new file mode 100644 index 00000000000..4e355bcf1bd --- /dev/null +++ b/tests/components/ai_task/test_task.py @@ -0,0 +1,53 @@ +"""Test tasks for the AI Task integration.""" + +from freezegun import freeze_time +import pytest +from syrupy.assertion import SnapshotAssertion + +from homeassistant.components.ai_task import GenTextTaskType, async_generate_text +from homeassistant.components.conversation import async_get_chat_log +from homeassistant.core import HomeAssistant +from homeassistant.helpers import chat_session + +from .conftest import TEST_ENTITY_ID + + +async def test_run_text_task_unknown_entity( + hass: HomeAssistant, + init_components: None, +) -> None: + """Test running a text task with an unknown entity.""" + + with pytest.raises( + ValueError, match="AI Task entity ai_task.unknown_entity not found" + ): + await async_generate_text( + hass, + task_name="Test Task", + entity_id="ai_task.unknown_entity", + task_type="summary", + instructions="Test prompt", + ) + + +@freeze_time("2025-06-14 22:59:00") +async def test_run_text_task_updates_chat_log( + hass: HomeAssistant, + init_components: None, + snapshot: SnapshotAssertion, +) -> None: + """Test that running a text task updates the chat log.""" + result = await async_generate_text( + hass, + task_name="Test Task", + entity_id=TEST_ENTITY_ID, + task_type=GenTextTaskType.SUMMARY, + instructions="Test prompt", + ) + assert result.result == "Mock result" + + with ( + chat_session.async_get_chat_session(hass, result.conversation_id) as session, + async_get_chat_log(hass, session) as chat_log, + ): + assert chat_log.content == snapshot