mirror of
https://github.com/home-assistant/core.git
synced 2025-07-29 16:17:20 +00:00
Add AI Task prefs
This commit is contained in:
parent
a8d4caab01
commit
17a5815ca1
@ -3,12 +3,12 @@
|
||||
import logging
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv, storage
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
|
||||
|
||||
from .const import DATA_COMPONENT, DOMAIN
|
||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, DOMAIN
|
||||
from .entity import AITaskEntity
|
||||
from .http import async_setup as async_setup_conversation_http
|
||||
from .task import GenTextTask, GenTextTaskResult, async_generate_text
|
||||
@ -33,6 +33,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Register the process service."""
|
||||
entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass)
|
||||
hass.data[DATA_COMPONENT] = entity_component
|
||||
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
|
||||
await hass.data[DATA_PREFERENCES].async_load()
|
||||
async_setup_conversation_http(hass)
|
||||
return True
|
||||
|
||||
@ -45,3 +47,53 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
|
||||
|
||||
|
||||
class AITaskPreferences:
|
||||
"""AI Task preferences."""
|
||||
|
||||
gen_text_entity_id: str | None = None
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the preferences."""
|
||||
self._store: storage.Store[dict[str, str | None]] = storage.Store(
|
||||
hass, 1, DOMAIN
|
||||
)
|
||||
|
||||
async def async_load(self) -> None:
|
||||
"""Load the data from the store."""
|
||||
data = await self._store.async_load()
|
||||
if data is None:
|
||||
return
|
||||
self.gen_text_entity_id = data.get("gen_text_entity_id")
|
||||
|
||||
@callback
|
||||
def async_set_preferences(
|
||||
self,
|
||||
*,
|
||||
gen_text_entity_id: str | None | UndefinedType = UNDEFINED,
|
||||
) -> None:
|
||||
"""Set the preferences."""
|
||||
changed = False
|
||||
for key, value in (("gen_text_entity_id", gen_text_entity_id),):
|
||||
if value is not UNDEFINED:
|
||||
if getattr(self, key) != value:
|
||||
setattr(self, key, value)
|
||||
changed = True
|
||||
|
||||
if not changed:
|
||||
return
|
||||
|
||||
self._store.async_delay_save(
|
||||
lambda: {
|
||||
"gen_text_entity_id": self.gen_text_entity_id,
|
||||
},
|
||||
10,
|
||||
)
|
||||
|
||||
@callback
|
||||
def as_dict(self) -> dict[str, str | None]:
|
||||
"""Get the current preferences."""
|
||||
return {
|
||||
"gen_text_entity_id": self.gen_text_entity_id,
|
||||
}
|
||||
|
@ -9,10 +9,12 @@ from homeassistant.util.hass_dict import HassKey
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from . import AITaskPreferences
|
||||
from .entity import AITaskEntity
|
||||
|
||||
DOMAIN = "ai_task"
|
||||
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
|
||||
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"You are a Home Assistant expert and help users with their tasks."
|
||||
|
@ -7,6 +7,7 @@ import voluptuous as vol
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
||||
from .const import DATA_PREFERENCES
|
||||
from .task import async_generate_text
|
||||
|
||||
|
||||
@ -14,13 +15,15 @@ from .task import async_generate_text
|
||||
def async_setup(hass: HomeAssistant) -> None:
|
||||
"""Set up the HTTP API for the conversation integration."""
|
||||
websocket_api.async_register_command(hass, websocket_generate_text)
|
||||
websocket_api.async_register_command(hass, websocket_get_preferences)
|
||||
websocket_api.async_register_command(hass, websocket_set_preferences)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "ai_task/generate_text",
|
||||
vol.Required("task_name"): str,
|
||||
vol.Required("entity_id"): str,
|
||||
vol.Optional("entity_id"): str,
|
||||
vol.Required("instructions"): str,
|
||||
}
|
||||
)
|
||||
@ -34,5 +37,46 @@ async def websocket_generate_text(
|
||||
"""Run a generate text task."""
|
||||
msg.pop("type")
|
||||
msg_id = msg.pop("id")
|
||||
result = await async_generate_text(hass=hass, **msg)
|
||||
try:
|
||||
result = await async_generate_text(hass=hass, **msg)
|
||||
except ValueError as err:
|
||||
connection.send_error(msg_id, websocket_api.const.ERR_UNKNOWN_ERROR, str(err))
|
||||
return
|
||||
connection.send_result(msg_id, result.as_dict())
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "ai_task/preferences/get",
|
||||
}
|
||||
)
|
||||
@callback
|
||||
def websocket_get_preferences(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Get AI task preferences."""
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
connection.send_result(msg["id"], preferences.as_dict())
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "ai_task/preferences/set",
|
||||
vol.Optional("gen_text_entity_id"): vol.Any(str, None),
|
||||
}
|
||||
)
|
||||
@websocket_api.require_admin
|
||||
@callback
|
||||
def websocket_set_preferences(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Set AI task preferences."""
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
msg.pop("type")
|
||||
msg_id = msg.pop("id")
|
||||
preferences.async_set_preferences(**msg)
|
||||
connection.send_result(msg_id, preferences.as_dict())
|
||||
|
@ -6,17 +6,23 @@ from dataclasses import dataclass
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import DATA_COMPONENT
|
||||
from .const import DATA_COMPONENT, DATA_PREFERENCES
|
||||
|
||||
|
||||
async def async_generate_text(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
task_name: str,
|
||||
entity_id: str,
|
||||
entity_id: str | None = None,
|
||||
instructions: str,
|
||||
) -> GenTextTaskResult:
|
||||
"""Run a task in the AI Task integration."""
|
||||
if entity_id is None:
|
||||
entity_id = hass.data[DATA_PREFERENCES].gen_text_entity_id
|
||||
|
||||
if entity_id is None:
|
||||
raise ValueError("No entity_id provided and no preferred entity set")
|
||||
|
||||
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
|
||||
if entity is None:
|
||||
raise ValueError(f"AI Task entity {entity_id} not found")
|
||||
|
@ -1,5 +1,8 @@
|
||||
"""Test the HTTP API for AI Task integration."""
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.ai_task import DATA_PREFERENCES
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
@ -8,12 +11,21 @@ from .conftest import TEST_ENTITY_ID
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"msg_extra",
|
||||
[
|
||||
{},
|
||||
{"entity_id": TEST_ENTITY_ID},
|
||||
],
|
||||
)
|
||||
async def test_ws_generate_text(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components: None,
|
||||
msg_extra: dict,
|
||||
) -> None:
|
||||
"""Test running a generate text task via the WebSocket API."""
|
||||
hass.data[DATA_PREFERENCES].async_set_preferences(gen_text_entity_id=TEST_ENTITY_ID)
|
||||
entity = hass.states.get(TEST_ENTITY_ID)
|
||||
assert entity is not None
|
||||
assert entity.state == STATE_UNKNOWN
|
||||
@ -24,9 +36,9 @@ async def test_ws_generate_text(
|
||||
{
|
||||
"type": "ai_task/generate_text",
|
||||
"task_name": "Test Task",
|
||||
"entity_id": TEST_ENTITY_ID,
|
||||
"instructions": "Test prompt",
|
||||
}
|
||||
| msg_extra
|
||||
)
|
||||
|
||||
msg = await client.receive_json()
|
||||
@ -36,3 +48,82 @@ async def test_ws_generate_text(
|
||||
|
||||
entity = hass.states.get(TEST_ENTITY_ID)
|
||||
assert entity.state != STATE_UNKNOWN
|
||||
|
||||
|
||||
async def test_ws_preferences(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components: None,
|
||||
) -> None:
|
||||
"""Test preferences via the WebSocket API."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
# Get initial preferences
|
||||
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": None,
|
||||
}
|
||||
|
||||
# Set preferences
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_text_entity_id": "ai_task.summary_1",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_1",
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_1",
|
||||
}
|
||||
|
||||
# Set only one preference
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
|
||||
# Clear a preference
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
await client.send_json_auto_id({"type": "ai_task/preferences/get"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
|
53
tests/components/ai_task/test_init.py
Normal file
53
tests/components/ai_task/test_init.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""Test initialization of the AI Task component."""
|
||||
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
|
||||
from homeassistant.components.ai_task import AITaskPreferences
|
||||
from homeassistant.components.ai_task.const import DATA_PREFERENCES
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from tests.common import flush_store
|
||||
|
||||
|
||||
async def test_preferences_storage_load(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
) -> None:
|
||||
"""Test that AITaskPreferences are stored and loaded correctly."""
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
|
||||
# Initial state should be None for entity IDs
|
||||
assert preferences.gen_text_entity_id is None
|
||||
|
||||
gen_text_id_1 = "sensor.summary_one"
|
||||
|
||||
preferences.async_set_preferences(
|
||||
gen_text_entity_id=gen_text_id_1,
|
||||
)
|
||||
|
||||
# Verify that current preferences object is updated
|
||||
assert preferences.gen_text_entity_id == gen_text_id_1
|
||||
|
||||
await flush_store(preferences._store)
|
||||
|
||||
# Create a new preferences instance to test loading from store
|
||||
new_preferences_instance = AITaskPreferences(hass)
|
||||
await new_preferences_instance.async_load()
|
||||
|
||||
assert new_preferences_instance.gen_text_entity_id == gen_text_id_1
|
||||
|
||||
# Test updating one preference and setting another to None
|
||||
gen_text_id_2 = "sensor.summary_two"
|
||||
preferences.async_set_preferences(gen_text_entity_id=gen_text_id_2)
|
||||
|
||||
# Verify that current preferences object is updated
|
||||
assert preferences.gen_text_entity_id == gen_text_id_2
|
||||
|
||||
await flush_store(preferences._store)
|
||||
|
||||
# Create another new preferences instance to confirm persistence of the update
|
||||
another_new_preferences_instance = AITaskPreferences(hass)
|
||||
await another_new_preferences_instance.async_load()
|
||||
|
||||
assert another_new_preferences_instance.gen_text_entity_id == gen_text_id_2
|
@ -4,14 +4,56 @@ from freezegun import freeze_time
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.ai_task import async_generate_text
|
||||
from homeassistant.components.ai_task import DATA_PREFERENCES, async_generate_text
|
||||
from homeassistant.components.conversation import async_get_chat_log
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import chat_session
|
||||
|
||||
from .conftest import TEST_ENTITY_ID
|
||||
|
||||
|
||||
async def test_run_task_preferred_entity(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
) -> None:
|
||||
"""Test running a task with an unknown entity."""
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="No entity_id provided and no preferred entity set"
|
||||
):
|
||||
await async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
preferences.async_set_preferences(gen_text_entity_id="ai_task.unknown")
|
||||
|
||||
with pytest.raises(ValueError, match="AI Task entity ai_task.unknown not found"):
|
||||
await async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
preferences.async_set_preferences(gen_text_entity_id=TEST_ENTITY_ID)
|
||||
state = hass.states.get(TEST_ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
result = await async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
instructions="Test prompt",
|
||||
)
|
||||
assert result.result == "Mock result"
|
||||
state = hass.states.get(TEST_ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state != STATE_UNKNOWN
|
||||
|
||||
|
||||
async def test_run_text_task_unknown_entity(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
|
Loading…
x
Reference in New Issue
Block a user