diff --git a/CODEOWNERS b/CODEOWNERS index 6670b411df4..1ceb6ff0e7d 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -57,6 +57,8 @@ build.json @home-assistant/supervisor /tests/components/aemet/ @Noltari /homeassistant/components/agent_dvr/ @ispysoftware /tests/components/agent_dvr/ @ispysoftware +/homeassistant/components/ai_task/ @home-assistant/core +/tests/components/ai_task/ @home-assistant/core /homeassistant/components/air_quality/ @home-assistant/core /tests/components/air_quality/ @home-assistant/core /homeassistant/components/airgradient/ @airgradienthq @joostlek diff --git a/homeassistant/components/ai_task/__init__.py b/homeassistant/components/ai_task/__init__.py new file mode 100644 index 00000000000..8b3d6e04966 --- /dev/null +++ b/homeassistant/components/ai_task/__init__.py @@ -0,0 +1,125 @@ +"""Integration to offer AI tasks to Home Assistant.""" + +import logging + +import voluptuous as vol + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import ( + HassJobType, + HomeAssistant, + ServiceCall, + ServiceResponse, + SupportsResponse, + callback, +) +from homeassistant.helpers import config_validation as cv, storage +from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType + +from .const import DATA_COMPONENT, DATA_PREFERENCES, DOMAIN, AITaskEntityFeature +from .entity import AITaskEntity +from .http import async_setup as async_setup_conversation_http +from .task import GenTextTask, GenTextTaskResult, async_generate_text + +__all__ = [ + "DOMAIN", + "AITaskEntity", + "AITaskEntityFeature", + "GenTextTask", + "GenTextTaskResult", + "async_generate_text", + "async_setup", + "async_setup_entry", + "async_unload_entry", +] + +_LOGGER = logging.getLogger(__name__) + +CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) + + +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + """Register the process service.""" + entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass) + hass.data[DATA_COMPONENT] = entity_component + hass.data[DATA_PREFERENCES] = AITaskPreferences(hass) + await hass.data[DATA_PREFERENCES].async_load() + async_setup_conversation_http(hass) + hass.services.async_register( + DOMAIN, + "generate_text", + async_service_generate_text, + schema=vol.Schema( + { + vol.Required("task_name"): cv.string, + vol.Optional("entity_id"): cv.entity_id, + vol.Required("instructions"): cv.string, + } + ), + supports_response=SupportsResponse.ONLY, + job_type=HassJobType.Coroutinefunction, + ) + return True + + +async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Set up a config entry.""" + return await hass.data[DATA_COMPONENT].async_setup_entry(entry) + + +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Unload a config entry.""" + return await hass.data[DATA_COMPONENT].async_unload_entry(entry) + + +async def async_service_generate_text(call: ServiceCall) -> ServiceResponse: + """Run the run task service.""" + result = await async_generate_text(hass=call.hass, **call.data) + return result.as_dict() # type: ignore[return-value] + + +class AITaskPreferences: + """AI Task preferences.""" + + KEYS = ("gen_text_entity_id",) + + gen_text_entity_id: str | None = None + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the preferences.""" + self._store: storage.Store[dict[str, str | None]] = storage.Store( + hass, 1, DOMAIN + ) + + async def async_load(self) -> None: + """Load the data from the store.""" + data = await self._store.async_load() + if data is None: + return + for key in self.KEYS: + setattr(self, key, data[key]) + + @callback + def async_set_preferences( + self, + *, + gen_text_entity_id: str | None | UndefinedType = UNDEFINED, + ) -> None: + """Set the preferences.""" + changed = False + for key, value in (("gen_text_entity_id", gen_text_entity_id),): + if value is not UNDEFINED: + if getattr(self, key) != value: + setattr(self, key, value) + changed = True + + if not changed: + return + + self._store.async_delay_save(self.as_dict, 10) + + @callback + def as_dict(self) -> dict[str, str | None]: + """Get the current preferences.""" + return {key: getattr(self, key) for key in self.KEYS} diff --git a/homeassistant/components/ai_task/const.py b/homeassistant/components/ai_task/const.py new file mode 100644 index 00000000000..69786178583 --- /dev/null +++ b/homeassistant/components/ai_task/const.py @@ -0,0 +1,29 @@ +"""Constants for the AI Task integration.""" + +from __future__ import annotations + +from enum import IntFlag +from typing import TYPE_CHECKING + +from homeassistant.util.hass_dict import HassKey + +if TYPE_CHECKING: + from homeassistant.helpers.entity_component import EntityComponent + + from . import AITaskPreferences + from .entity import AITaskEntity + +DOMAIN = "ai_task" +DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN) +DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences") + +DEFAULT_SYSTEM_PROMPT = ( + "You are a Home Assistant expert and help users with their tasks." +) + + +class AITaskEntityFeature(IntFlag): + """Supported features of the AI task entity.""" + + GENERATE_TEXT = 1 + """Generate text based on instructions.""" diff --git a/homeassistant/components/ai_task/entity.py b/homeassistant/components/ai_task/entity.py new file mode 100644 index 00000000000..88ce8144fb7 --- /dev/null +++ b/homeassistant/components/ai_task/entity.py @@ -0,0 +1,103 @@ +"""Entity for the AI Task integration.""" + +from collections.abc import AsyncGenerator +import contextlib +from typing import final + +from propcache.api import cached_property + +from homeassistant.components.conversation import ( + ChatLog, + UserContent, + async_get_chat_log, +) +from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN +from homeassistant.helpers import llm +from homeassistant.helpers.chat_session import async_get_chat_session +from homeassistant.helpers.restore_state import RestoreEntity +from homeassistant.util import dt as dt_util + +from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN, AITaskEntityFeature +from .task import GenTextTask, GenTextTaskResult + + +class AITaskEntity(RestoreEntity): + """Entity that supports conversations.""" + + _attr_should_poll = False + _attr_supported_features = AITaskEntityFeature(0) + __last_activity: str | None = None + + @property + @final + def state(self) -> str | None: + """Return the state of the entity.""" + if self.__last_activity is None: + return None + return self.__last_activity + + @cached_property + def supported_features(self) -> AITaskEntityFeature: + """Flag supported features.""" + return self._attr_supported_features + + async def async_internal_added_to_hass(self) -> None: + """Call when the entity is added to hass.""" + await super().async_internal_added_to_hass() + state = await self.async_get_last_state() + if ( + state is not None + and state.state is not None + and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) + ): + self.__last_activity = state.state + + @final + @contextlib.asynccontextmanager + async def _async_get_ai_task_chat_log( + self, + task: GenTextTask, + ) -> AsyncGenerator[ChatLog]: + """Context manager used to manage the ChatLog used during an AI Task.""" + # pylint: disable-next=contextmanager-generator-missing-cleanup + with ( + async_get_chat_session(self.hass) as session, + async_get_chat_log( + self.hass, + session, + None, + ) as chat_log, + ): + await chat_log.async_provide_llm_data( + llm.LLMContext( + platform=self.platform.domain, + context=None, + language=None, + assistant=DOMAIN, + device_id=None, + ), + user_llm_prompt=DEFAULT_SYSTEM_PROMPT, + ) + + chat_log.async_add_user_content(UserContent(task.instructions)) + + yield chat_log + + @final + async def internal_async_generate_text( + self, + task: GenTextTask, + ) -> GenTextTaskResult: + """Run a gen text task.""" + self.__last_activity = dt_util.utcnow().isoformat() + self.async_write_ha_state() + async with self._async_get_ai_task_chat_log(task) as chat_log: + return await self._async_generate_text(task, chat_log) + + async def _async_generate_text( + self, + task: GenTextTask, + chat_log: ChatLog, + ) -> GenTextTaskResult: + """Handle a gen text task.""" + raise NotImplementedError diff --git a/homeassistant/components/ai_task/http.py b/homeassistant/components/ai_task/http.py new file mode 100644 index 00000000000..6d44a4e8d3c --- /dev/null +++ b/homeassistant/components/ai_task/http.py @@ -0,0 +1,54 @@ +"""HTTP endpoint for AI Task integration.""" + +from typing import Any + +import voluptuous as vol + +from homeassistant.components import websocket_api +from homeassistant.core import HomeAssistant, callback + +from .const import DATA_PREFERENCES + + +@callback +def async_setup(hass: HomeAssistant) -> None: + """Set up the HTTP API for the conversation integration.""" + websocket_api.async_register_command(hass, websocket_get_preferences) + websocket_api.async_register_command(hass, websocket_set_preferences) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "ai_task/preferences/get", + } +) +@callback +def websocket_get_preferences( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Get AI task preferences.""" + preferences = hass.data[DATA_PREFERENCES] + connection.send_result(msg["id"], preferences.as_dict()) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "ai_task/preferences/set", + vol.Optional("gen_text_entity_id"): vol.Any(str, None), + } +) +@websocket_api.require_admin +@callback +def websocket_set_preferences( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Set AI task preferences.""" + preferences = hass.data[DATA_PREFERENCES] + msg.pop("type") + msg_id = msg.pop("id") + preferences.async_set_preferences(**msg) + connection.send_result(msg_id, preferences.as_dict()) diff --git a/homeassistant/components/ai_task/icons.json b/homeassistant/components/ai_task/icons.json new file mode 100644 index 00000000000..cb09e5c8f5d --- /dev/null +++ b/homeassistant/components/ai_task/icons.json @@ -0,0 +1,7 @@ +{ + "services": { + "generate_text": { + "service": "mdi:file-star-four-points-outline" + } + } +} diff --git a/homeassistant/components/ai_task/manifest.json b/homeassistant/components/ai_task/manifest.json new file mode 100644 index 00000000000..c685410530d --- /dev/null +++ b/homeassistant/components/ai_task/manifest.json @@ -0,0 +1,9 @@ +{ + "domain": "ai_task", + "name": "AI Task", + "codeowners": ["@home-assistant/core"], + "dependencies": ["conversation"], + "documentation": "https://www.home-assistant.io/integrations/ai_task", + "integration_type": "system", + "quality_scale": "internal" +} diff --git a/homeassistant/components/ai_task/services.yaml b/homeassistant/components/ai_task/services.yaml new file mode 100644 index 00000000000..32715bf77d7 --- /dev/null +++ b/homeassistant/components/ai_task/services.yaml @@ -0,0 +1,19 @@ +generate_text: + fields: + task_name: + example: "home summary" + required: true + selector: + text: + instructions: + example: "Generate a funny notification that garage door was left open" + required: true + selector: + text: + entity_id: + required: false + selector: + entity: + domain: ai_task + supported_features: + - ai_task.AITaskEntityFeature.GENERATE_TEXT diff --git a/homeassistant/components/ai_task/strings.json b/homeassistant/components/ai_task/strings.json new file mode 100644 index 00000000000..1cdbf20ba4f --- /dev/null +++ b/homeassistant/components/ai_task/strings.json @@ -0,0 +1,22 @@ +{ + "services": { + "generate_text": { + "name": "Generate text", + "description": "Use AI to run a task that generates text.", + "fields": { + "task_name": { + "name": "Task Name", + "description": "Name of the task." + }, + "instructions": { + "name": "Instructions", + "description": "Instructions on what needs to be done." + }, + "entity_id": { + "name": "Entity ID", + "description": "Entity ID to run the task on. If not provided, the preferred entity will be used." + } + } + } + } +} diff --git a/homeassistant/components/ai_task/task.py b/homeassistant/components/ai_task/task.py new file mode 100644 index 00000000000..d0c59fdd09a --- /dev/null +++ b/homeassistant/components/ai_task/task.py @@ -0,0 +1,71 @@ +"""AI tasks to be handled by agents.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from homeassistant.core import HomeAssistant + +from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature + + +async def async_generate_text( + hass: HomeAssistant, + *, + task_name: str, + entity_id: str | None = None, + instructions: str, +) -> GenTextTaskResult: + """Run a task in the AI Task integration.""" + if entity_id is None: + entity_id = hass.data[DATA_PREFERENCES].gen_text_entity_id + + if entity_id is None: + raise ValueError("No entity_id provided and no preferred entity set") + + entity = hass.data[DATA_COMPONENT].get_entity(entity_id) + if entity is None: + raise ValueError(f"AI Task entity {entity_id} not found") + + if AITaskEntityFeature.GENERATE_TEXT not in entity.supported_features: + raise ValueError(f"AI Task entity {entity_id} does not support generating text") + + return await entity.internal_async_generate_text( + GenTextTask( + name=task_name, + instructions=instructions, + ) + ) + + +@dataclass(slots=True) +class GenTextTask: + """Gen text task to be processed.""" + + name: str + """Name of the task.""" + + instructions: str + """Instructions on what needs to be done.""" + + def __str__(self) -> str: + """Return task as a string.""" + return f"" + + +@dataclass(slots=True) +class GenTextTaskResult: + """Result of gen text task.""" + + conversation_id: str + """Unique identifier for the conversation.""" + + text: str + """Generated text.""" + + def as_dict(self) -> dict[str, str]: + """Return result as a dict.""" + return { + "conversation_id": self.conversation_id, + "text": self.text, + } diff --git a/homeassistant/const.py b/homeassistant/const.py index f692f428920..0abdcd59b77 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -40,6 +40,7 @@ PLATFORM_FORMAT: Final = "{platform}.{domain}" class Platform(StrEnum): """Available entity platforms.""" + AI_TASK = "ai_task" AIR_QUALITY = "air_quality" ALARM_CONTROL_PANEL = "alarm_control_panel" ASSIST_SATELLITE = "assist_satellite" diff --git a/tests/components/ai_task/__init__.py b/tests/components/ai_task/__init__.py new file mode 100644 index 00000000000..b4ca4688eb4 --- /dev/null +++ b/tests/components/ai_task/__init__.py @@ -0,0 +1 @@ +"""Tests for the AI Task integration.""" diff --git a/tests/components/ai_task/conftest.py b/tests/components/ai_task/conftest.py new file mode 100644 index 00000000000..2060c51bfa4 --- /dev/null +++ b/tests/components/ai_task/conftest.py @@ -0,0 +1,127 @@ +"""Test helpers for AI Task integration.""" + +import pytest + +from homeassistant.components.ai_task import ( + DOMAIN, + AITaskEntity, + AITaskEntityFeature, + GenTextTask, + GenTextTaskResult, +) +from homeassistant.components.conversation import AssistantContent, ChatLog +from homeassistant.config_entries import ConfigEntry, ConfigFlow +from homeassistant.const import Platform +from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback +from homeassistant.setup import async_setup_component + +from tests.common import ( + MockConfigEntry, + MockModule, + MockPlatform, + mock_config_flow, + mock_integration, + mock_platform, +) + +TEST_DOMAIN = "test" +TEST_ENTITY_ID = "ai_task.test_task_entity" + + +class MockAITaskEntity(AITaskEntity): + """Mock AI Task entity for testing.""" + + _attr_name = "Test Task Entity" + _attr_supported_features = AITaskEntityFeature.GENERATE_TEXT + + def __init__(self) -> None: + """Initialize the mock entity.""" + super().__init__() + self.mock_generate_text_tasks = [] + + async def _async_generate_text( + self, task: GenTextTask, chat_log: ChatLog + ) -> GenTextTaskResult: + """Mock handling of generate text task.""" + self.mock_generate_text_tasks.append(task) + chat_log.async_add_assistant_content_without_tools( + AssistantContent(self.entity_id, "Mock result") + ) + return GenTextTaskResult( + conversation_id=chat_log.conversation_id, + text="Mock result", + ) + + +@pytest.fixture +def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: + """Mock a configuration entry for AI Task.""" + entry = MockConfigEntry(domain=TEST_DOMAIN, entry_id="mock-test-entry") + entry.add_to_hass(hass) + return entry + + +@pytest.fixture +def mock_ai_task_entity( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> MockAITaskEntity: + """Mock AI Task entity.""" + return MockAITaskEntity() + + +@pytest.fixture +async def init_components( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_ai_task_entity: MockAITaskEntity, +): + """Initialize the AI Task integration with a mock entity.""" + assert await async_setup_component(hass, "homeassistant", {}) + + async def async_setup_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Set up test config entry.""" + await hass.config_entries.async_forward_entry_setups( + config_entry, [Platform.AI_TASK] + ) + return True + + async def async_unload_entry_init( + hass: HomeAssistant, config_entry: ConfigEntry + ) -> bool: + """Unload test config entry.""" + await hass.config_entries.async_forward_entry_unload( + config_entry, Platform.AI_TASK + ) + return True + + mock_integration( + hass, + MockModule( + TEST_DOMAIN, + async_setup_entry=async_setup_entry_init, + async_unload_entry=async_unload_entry_init, + ), + ) + + async def async_setup_entry_platform( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddConfigEntryEntitiesCallback, + ) -> None: + """Set up test tts platform via config entry.""" + async_add_entities([mock_ai_task_entity]) + + mock_platform( + hass, + f"{TEST_DOMAIN}.{DOMAIN}", + MockPlatform(async_setup_entry=async_setup_entry_platform), + ) + + mock_platform(hass, f"{TEST_DOMAIN}.config_flow") + + with mock_config_flow(TEST_DOMAIN, ConfigFlow): + assert await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() diff --git a/tests/components/ai_task/snapshots/test_task.ambr b/tests/components/ai_task/snapshots/test_task.ambr new file mode 100644 index 00000000000..6d155c82a68 --- /dev/null +++ b/tests/components/ai_task/snapshots/test_task.ambr @@ -0,0 +1,22 @@ +# serializer version: 1 +# name: test_run_text_task_updates_chat_log + list([ + dict({ + 'content': ''' + You are a Home Assistant expert and help users with their tasks. + Current time is 15:59:00. Today's date is 2025-06-14. + ''', + 'role': 'system', + }), + dict({ + 'content': 'Test prompt', + 'role': 'user', + }), + dict({ + 'agent_id': 'ai_task.test_task_entity', + 'content': 'Mock result', + 'role': 'assistant', + 'tool_calls': None, + }), + ]) +# --- diff --git a/tests/components/ai_task/test_entity.py b/tests/components/ai_task/test_entity.py new file mode 100644 index 00000000000..aa9afbf6560 --- /dev/null +++ b/tests/components/ai_task/test_entity.py @@ -0,0 +1,39 @@ +"""Tests for the AI Task entity model.""" + +from freezegun import freeze_time + +from homeassistant.components.ai_task import async_generate_text +from homeassistant.const import STATE_UNKNOWN +from homeassistant.core import HomeAssistant + +from .conftest import TEST_ENTITY_ID, MockAITaskEntity + +from tests.common import MockConfigEntry + + +@freeze_time("2025-06-08 16:28:13") +async def test_state_generate_text( + hass: HomeAssistant, + init_components: None, + mock_config_entry: MockConfigEntry, + mock_ai_task_entity: MockAITaskEntity, +) -> None: + """Test the state of the AI Task entity is updated when generating text.""" + entity = hass.states.get(TEST_ENTITY_ID) + assert entity is not None + assert entity.state == STATE_UNKNOWN + + result = await async_generate_text( + hass, + task_name="Test task", + entity_id=TEST_ENTITY_ID, + instructions="Test prompt", + ) + assert result.text == "Mock result" + + entity = hass.states.get(TEST_ENTITY_ID) + assert entity.state == "2025-06-08T16:28:13+00:00" + + assert mock_ai_task_entity.mock_generate_text_tasks + task = mock_ai_task_entity.mock_generate_text_tasks[0] + assert task.instructions == "Test prompt" diff --git a/tests/components/ai_task/test_http.py b/tests/components/ai_task/test_http.py new file mode 100644 index 00000000000..4436e1d45d5 --- /dev/null +++ b/tests/components/ai_task/test_http.py @@ -0,0 +1,84 @@ +"""Test the HTTP API for AI Task integration.""" + +from homeassistant.core import HomeAssistant + +from tests.typing import WebSocketGenerator + + +async def test_ws_preferences( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components: None, +) -> None: + """Test preferences via the WebSocket API.""" + client = await hass_ws_client(hass) + + # Get initial preferences + await client.send_json_auto_id({"type": "ai_task/preferences/get"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": None, + } + + # Set preferences + await client.send_json_auto_id( + { + "type": "ai_task/preferences/set", + "gen_text_entity_id": "ai_task.summary_1", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": "ai_task.summary_1", + } + + # Get updated preferences + await client.send_json_auto_id({"type": "ai_task/preferences/get"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": "ai_task.summary_1", + } + + # Update an existing preference + await client.send_json_auto_id( + { + "type": "ai_task/preferences/set", + "gen_text_entity_id": "ai_task.summary_2", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": "ai_task.summary_2", + } + + # Get updated preferences + await client.send_json_auto_id({"type": "ai_task/preferences/get"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": "ai_task.summary_2", + } + + # No preferences set will preserve existing preferences + await client.send_json_auto_id( + { + "type": "ai_task/preferences/set", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": "ai_task.summary_2", + } + + # Get updated preferences + await client.send_json_auto_id({"type": "ai_task/preferences/get"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": "ai_task.summary_2", + } diff --git a/tests/components/ai_task/test_init.py b/tests/components/ai_task/test_init.py new file mode 100644 index 00000000000..2f45d812b1f --- /dev/null +++ b/tests/components/ai_task/test_init.py @@ -0,0 +1,84 @@ +"""Test initialization of the AI Task component.""" + +from freezegun.api import FrozenDateTimeFactory +import pytest + +from homeassistant.components.ai_task import AITaskPreferences +from homeassistant.components.ai_task.const import DATA_PREFERENCES +from homeassistant.core import HomeAssistant + +from .conftest import TEST_ENTITY_ID + +from tests.common import flush_store + + +async def test_preferences_storage_load( + hass: HomeAssistant, +) -> None: + """Test that AITaskPreferences are stored and loaded correctly.""" + preferences = AITaskPreferences(hass) + await preferences.async_load() + + # Initial state should be None for entity IDs + for key in AITaskPreferences.KEYS: + assert getattr(preferences, key) is None, f"Initial {key} should be None" + + new_values = {key: f"ai_task.test_{key}" for key in AITaskPreferences.KEYS} + + preferences.async_set_preferences(**new_values) + + # Verify that current preferences object is updated + for key, value in new_values.items(): + assert getattr(preferences, key) == value, ( + f"Current {key} should match set value" + ) + + await flush_store(preferences._store) + + # Create a new preferences instance to test loading from store + new_preferences_instance = AITaskPreferences(hass) + await new_preferences_instance.async_load() + + for key in AITaskPreferences.KEYS: + assert getattr(preferences, key) == getattr(new_preferences_instance, key), ( + f"Loaded {key} should match saved value" + ) + + +@pytest.mark.parametrize( + ("set_preferences", "msg_extra"), + [ + ( + {"gen_text_entity_id": TEST_ENTITY_ID}, + {}, + ), + ( + {}, + {"entity_id": TEST_ENTITY_ID}, + ), + ], +) +async def test_generate_text_service( + hass: HomeAssistant, + init_components: None, + freezer: FrozenDateTimeFactory, + set_preferences: dict[str, str | None], + msg_extra: dict[str, str], +) -> None: + """Test the generate text service.""" + preferences = hass.data[DATA_PREFERENCES] + preferences.async_set_preferences(**set_preferences) + + result = await hass.services.async_call( + "ai_task", + "generate_text", + { + "task_name": "Test Name", + "instructions": "Test prompt", + } + | msg_extra, + blocking=True, + return_response=True, + ) + + assert result["text"] == "Mock result" diff --git a/tests/components/ai_task/test_task.py b/tests/components/ai_task/test_task.py new file mode 100644 index 00000000000..d4df66d83f9 --- /dev/null +++ b/tests/components/ai_task/test_task.py @@ -0,0 +1,123 @@ +"""Test tasks for the AI Task integration.""" + +from freezegun import freeze_time +import pytest +from syrupy.assertion import SnapshotAssertion + +from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_text +from homeassistant.components.conversation import async_get_chat_log +from homeassistant.const import STATE_UNKNOWN +from homeassistant.core import HomeAssistant +from homeassistant.helpers import chat_session + +from .conftest import TEST_ENTITY_ID, MockAITaskEntity + +from tests.typing import WebSocketGenerator + + +async def test_run_task_preferred_entity( + hass: HomeAssistant, + init_components: None, + mock_ai_task_entity: MockAITaskEntity, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test running a task with an unknown entity.""" + client = await hass_ws_client(hass) + + with pytest.raises( + ValueError, match="No entity_id provided and no preferred entity set" + ): + await async_generate_text( + hass, + task_name="Test Task", + instructions="Test prompt", + ) + + await client.send_json_auto_id( + { + "type": "ai_task/preferences/set", + "gen_text_entity_id": "ai_task.unknown", + } + ) + msg = await client.receive_json() + assert msg["success"] + + with pytest.raises(ValueError, match="AI Task entity ai_task.unknown not found"): + await async_generate_text( + hass, + task_name="Test Task", + instructions="Test prompt", + ) + + await client.send_json_auto_id( + { + "type": "ai_task/preferences/set", + "gen_text_entity_id": TEST_ENTITY_ID, + } + ) + msg = await client.receive_json() + assert msg["success"] + + state = hass.states.get(TEST_ENTITY_ID) + assert state is not None + assert state.state == STATE_UNKNOWN + + result = await async_generate_text( + hass, + task_name="Test Task", + instructions="Test prompt", + ) + assert result.text == "Mock result" + state = hass.states.get(TEST_ENTITY_ID) + assert state is not None + assert state.state != STATE_UNKNOWN + + mock_ai_task_entity.supported_features = AITaskEntityFeature(0) + with pytest.raises( + ValueError, + match="AI Task entity ai_task.test_task_entity does not support generating text", + ): + await async_generate_text( + hass, + task_name="Test Task", + instructions="Test prompt", + ) + + +async def test_run_text_task_unknown_entity( + hass: HomeAssistant, + init_components: None, +) -> None: + """Test running a text task with an unknown entity.""" + + with pytest.raises( + ValueError, match="AI Task entity ai_task.unknown_entity not found" + ): + await async_generate_text( + hass, + task_name="Test Task", + entity_id="ai_task.unknown_entity", + instructions="Test prompt", + ) + + +@freeze_time("2025-06-14 22:59:00") +async def test_run_text_task_updates_chat_log( + hass: HomeAssistant, + init_components: None, + snapshot: SnapshotAssertion, +) -> None: + """Test that running a text task updates the chat log.""" + result = await async_generate_text( + hass, + task_name="Test Task", + entity_id=TEST_ENTITY_ID, + instructions="Test prompt", + ) + assert result.text == "Mock result" + + with ( + chat_session.async_get_chat_session(hass, result.conversation_id) as session, + async_get_chat_log(hass, session) as chat_log, + ): + assert chat_log.content == snapshot