diff --git a/homeassistant/components/ai_task/__init__.py b/homeassistant/components/ai_task/__init__.py index 5d8082e06b3..2021a92d6e2 100644 --- a/homeassistant/components/ai_task/__init__.py +++ b/homeassistant/components/ai_task/__init__.py @@ -3,12 +3,12 @@ import logging from homeassistant.config_entries import ConfigEntry -from homeassistant.core import HomeAssistant -from homeassistant.helpers import config_validation as cv +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import config_validation as cv, storage from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType -from .const import DATA_COMPONENT, DOMAIN +from .const import DATA_COMPONENT, DATA_PREFERENCES, DOMAIN from .entity import AITaskEntity from .http import async_setup as async_setup_conversation_http from .task import GenTextTask, GenTextTaskResult, async_generate_text @@ -33,6 +33,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Register the process service.""" entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass) hass.data[DATA_COMPONENT] = entity_component + hass.data[DATA_PREFERENCES] = AITaskPreferences(hass) + await hass.data[DATA_PREFERENCES].async_load() async_setup_conversation_http(hass) return True @@ -45,3 +47,53 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" return await hass.data[DATA_COMPONENT].async_unload_entry(entry) + + +class AITaskPreferences: + """AI Task preferences.""" + + gen_text_entity_id: str | None = None + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the preferences.""" + self._store: storage.Store[dict[str, str | None]] = storage.Store( + hass, 1, DOMAIN + ) + + async def async_load(self) -> None: + """Load the data from the store.""" + data = await self._store.async_load() + if data is None: + return + self.gen_text_entity_id = data.get("gen_text_entity_id") + + @callback + def async_set_preferences( + self, + *, + gen_text_entity_id: str | None | UndefinedType = UNDEFINED, + ) -> None: + """Set the preferences.""" + changed = False + for key, value in (("gen_text_entity_id", gen_text_entity_id),): + if value is not UNDEFINED: + if getattr(self, key) != value: + setattr(self, key, value) + changed = True + + if not changed: + return + + self._store.async_delay_save( + lambda: { + "gen_text_entity_id": self.gen_text_entity_id, + }, + 10, + ) + + @callback + def as_dict(self) -> dict[str, str | None]: + """Get the current preferences.""" + return { + "gen_text_entity_id": self.gen_text_entity_id, + } diff --git a/homeassistant/components/ai_task/const.py b/homeassistant/components/ai_task/const.py index 7c5b4e281d2..03809da5f4a 100644 --- a/homeassistant/components/ai_task/const.py +++ b/homeassistant/components/ai_task/const.py @@ -9,10 +9,12 @@ from homeassistant.util.hass_dict import HassKey if TYPE_CHECKING: from homeassistant.helpers.entity_component import EntityComponent + from . import AITaskPreferences from .entity import AITaskEntity DOMAIN = "ai_task" DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN) +DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences") DEFAULT_SYSTEM_PROMPT = ( "You are a Home Assistant expert and help users with their tasks." diff --git a/homeassistant/components/ai_task/http.py b/homeassistant/components/ai_task/http.py index 22a603662ee..e82d71a218f 100644 --- a/homeassistant/components/ai_task/http.py +++ b/homeassistant/components/ai_task/http.py @@ -7,6 +7,7 @@ import voluptuous as vol from homeassistant.components import websocket_api from homeassistant.core import HomeAssistant, callback +from .const import DATA_PREFERENCES from .task import async_generate_text @@ -14,13 +15,15 @@ from .task import async_generate_text def async_setup(hass: HomeAssistant) -> None: """Set up the HTTP API for the conversation integration.""" websocket_api.async_register_command(hass, websocket_generate_text) + websocket_api.async_register_command(hass, websocket_get_preferences) + websocket_api.async_register_command(hass, websocket_set_preferences) @websocket_api.websocket_command( { vol.Required("type"): "ai_task/generate_text", vol.Required("task_name"): str, - vol.Required("entity_id"): str, + vol.Optional("entity_id"): str, vol.Required("instructions"): str, } ) @@ -34,5 +37,46 @@ async def websocket_generate_text( """Run a generate text task.""" msg.pop("type") msg_id = msg.pop("id") - result = await async_generate_text(hass=hass, **msg) + try: + result = await async_generate_text(hass=hass, **msg) + except ValueError as err: + connection.send_error(msg_id, websocket_api.const.ERR_UNKNOWN_ERROR, str(err)) + return connection.send_result(msg_id, result.as_dict()) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "ai_task/preferences/get", + } +) +@callback +def websocket_get_preferences( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Get AI task preferences.""" + preferences = hass.data[DATA_PREFERENCES] + connection.send_result(msg["id"], preferences.as_dict()) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "ai_task/preferences/set", + vol.Optional("gen_text_entity_id"): vol.Any(str, None), + } +) +@websocket_api.require_admin +@callback +def websocket_set_preferences( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Set AI task preferences.""" + preferences = hass.data[DATA_PREFERENCES] + msg.pop("type") + msg_id = msg.pop("id") + preferences.async_set_preferences(**msg) + connection.send_result(msg_id, preferences.as_dict()) diff --git a/homeassistant/components/ai_task/task.py b/homeassistant/components/ai_task/task.py index 0a286cb7fcf..e326d83c9d7 100644 --- a/homeassistant/components/ai_task/task.py +++ b/homeassistant/components/ai_task/task.py @@ -6,17 +6,23 @@ from dataclasses import dataclass from homeassistant.core import HomeAssistant -from .const import DATA_COMPONENT +from .const import DATA_COMPONENT, DATA_PREFERENCES async def async_generate_text( hass: HomeAssistant, *, task_name: str, - entity_id: str, + entity_id: str | None = None, instructions: str, ) -> GenTextTaskResult: """Run a task in the AI Task integration.""" + if entity_id is None: + entity_id = hass.data[DATA_PREFERENCES].gen_text_entity_id + + if entity_id is None: + raise ValueError("No entity_id provided and no preferred entity set") + entity = hass.data[DATA_COMPONENT].get_entity(entity_id) if entity is None: raise ValueError(f"AI Task entity {entity_id} not found") diff --git a/tests/components/ai_task/test_http.py b/tests/components/ai_task/test_http.py index ac3395e84b0..c2e84f86f98 100644 --- a/tests/components/ai_task/test_http.py +++ b/tests/components/ai_task/test_http.py @@ -1,5 +1,8 @@ """Test the HTTP API for AI Task integration.""" +import pytest + +from homeassistant.components.ai_task import DATA_PREFERENCES from homeassistant.const import STATE_UNKNOWN from homeassistant.core import HomeAssistant @@ -8,12 +11,21 @@ from .conftest import TEST_ENTITY_ID from tests.typing import WebSocketGenerator +@pytest.mark.parametrize( + "msg_extra", + [ + {}, + {"entity_id": TEST_ENTITY_ID}, + ], +) async def test_ws_generate_text( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components: None, + msg_extra: dict, ) -> None: """Test running a generate text task via the WebSocket API.""" + hass.data[DATA_PREFERENCES].async_set_preferences(gen_text_entity_id=TEST_ENTITY_ID) entity = hass.states.get(TEST_ENTITY_ID) assert entity is not None assert entity.state == STATE_UNKNOWN @@ -24,9 +36,9 @@ async def test_ws_generate_text( { "type": "ai_task/generate_text", "task_name": "Test Task", - "entity_id": TEST_ENTITY_ID, "instructions": "Test prompt", } + | msg_extra ) msg = await client.receive_json() @@ -36,3 +48,82 @@ async def test_ws_generate_text( entity = hass.states.get(TEST_ENTITY_ID) assert entity.state != STATE_UNKNOWN + + +async def test_ws_preferences( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + init_components: None, +) -> None: + """Test preferences via the WebSocket API.""" + client = await hass_ws_client(hass) + + # Get initial preferences + await client.send_json_auto_id({"type": "ai_task/preferences/get"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": None, + } + + # Set preferences + await client.send_json_auto_id( + { + "type": "ai_task/preferences/set", + "gen_text_entity_id": "ai_task.summary_1", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": "ai_task.summary_1", + } + + # Get updated preferences + await client.send_json_auto_id({"type": "ai_task/preferences/get"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": "ai_task.summary_1", + } + + # Set only one preference + await client.send_json_auto_id( + { + "type": "ai_task/preferences/set", + "gen_text_entity_id": "ai_task.summary_2", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": "ai_task.summary_2", + } + + # Get updated preferences + await client.send_json_auto_id({"type": "ai_task/preferences/get"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": "ai_task.summary_2", + } + + # Clear a preference + await client.send_json_auto_id( + { + "type": "ai_task/preferences/set", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": "ai_task.summary_2", + } + + # Get updated preferences + await client.send_json_auto_id({"type": "ai_task/preferences/get"}) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "gen_text_entity_id": "ai_task.summary_2", + } diff --git a/tests/components/ai_task/test_init.py b/tests/components/ai_task/test_init.py new file mode 100644 index 00000000000..2498ee2680a --- /dev/null +++ b/tests/components/ai_task/test_init.py @@ -0,0 +1,53 @@ +"""Test initialization of the AI Task component.""" + +from freezegun.api import FrozenDateTimeFactory + +from homeassistant.components.ai_task import AITaskPreferences +from homeassistant.components.ai_task.const import DATA_PREFERENCES +from homeassistant.core import HomeAssistant + +from tests.common import flush_store + + +async def test_preferences_storage_load( + hass: HomeAssistant, + init_components: None, + freezer: FrozenDateTimeFactory, +) -> None: + """Test that AITaskPreferences are stored and loaded correctly.""" + preferences = hass.data[DATA_PREFERENCES] + + # Initial state should be None for entity IDs + assert preferences.gen_text_entity_id is None + + gen_text_id_1 = "sensor.summary_one" + + preferences.async_set_preferences( + gen_text_entity_id=gen_text_id_1, + ) + + # Verify that current preferences object is updated + assert preferences.gen_text_entity_id == gen_text_id_1 + + await flush_store(preferences._store) + + # Create a new preferences instance to test loading from store + new_preferences_instance = AITaskPreferences(hass) + await new_preferences_instance.async_load() + + assert new_preferences_instance.gen_text_entity_id == gen_text_id_1 + + # Test updating one preference and setting another to None + gen_text_id_2 = "sensor.summary_two" + preferences.async_set_preferences(gen_text_entity_id=gen_text_id_2) + + # Verify that current preferences object is updated + assert preferences.gen_text_entity_id == gen_text_id_2 + + await flush_store(preferences._store) + + # Create another new preferences instance to confirm persistence of the update + another_new_preferences_instance = AITaskPreferences(hass) + await another_new_preferences_instance.async_load() + + assert another_new_preferences_instance.gen_text_entity_id == gen_text_id_2 diff --git a/tests/components/ai_task/test_task.py b/tests/components/ai_task/test_task.py index 56a1ff20a6a..7a9d01e867d 100644 --- a/tests/components/ai_task/test_task.py +++ b/tests/components/ai_task/test_task.py @@ -4,14 +4,56 @@ from freezegun import freeze_time import pytest from syrupy.assertion import SnapshotAssertion -from homeassistant.components.ai_task import async_generate_text +from homeassistant.components.ai_task import DATA_PREFERENCES, async_generate_text from homeassistant.components.conversation import async_get_chat_log +from homeassistant.const import STATE_UNKNOWN from homeassistant.core import HomeAssistant from homeassistant.helpers import chat_session from .conftest import TEST_ENTITY_ID +async def test_run_task_preferred_entity( + hass: HomeAssistant, + init_components: None, +) -> None: + """Test running a task with an unknown entity.""" + preferences = hass.data[DATA_PREFERENCES] + + with pytest.raises( + ValueError, match="No entity_id provided and no preferred entity set" + ): + await async_generate_text( + hass, + task_name="Test Task", + instructions="Test prompt", + ) + + preferences.async_set_preferences(gen_text_entity_id="ai_task.unknown") + + with pytest.raises(ValueError, match="AI Task entity ai_task.unknown not found"): + await async_generate_text( + hass, + task_name="Test Task", + instructions="Test prompt", + ) + + preferences.async_set_preferences(gen_text_entity_id=TEST_ENTITY_ID) + state = hass.states.get(TEST_ENTITY_ID) + assert state is not None + assert state.state == STATE_UNKNOWN + + result = await async_generate_text( + hass, + task_name="Test Task", + instructions="Test prompt", + ) + assert result.result == "Mock result" + state = hass.states.get(TEST_ENTITY_ID) + assert state is not None + assert state.state != STATE_UNKNOWN + + async def test_run_text_task_unknown_entity( hass: HomeAssistant, init_components: None,