AI task generate_text -> generate_data (#147370)

This commit is contained in:
Paulus Schoutsen 2025-06-24 07:12:29 -04:00 committed by GitHub
parent 38c7eaf70a
commit 63ac14a19b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 104 additions and 100 deletions

View File

@ -24,20 +24,20 @@ from .const import (
DATA_COMPONENT,
DATA_PREFERENCES,
DOMAIN,
SERVICE_GENERATE_TEXT,
SERVICE_GENERATE_DATA,
AITaskEntityFeature,
)
from .entity import AITaskEntity
from .http import async_setup as async_setup_http
from .task import GenTextTask, GenTextTaskResult, async_generate_text
from .task import GenDataTask, GenDataTaskResult, async_generate_data
__all__ = [
"DOMAIN",
"AITaskEntity",
"AITaskEntityFeature",
"GenTextTask",
"GenTextTaskResult",
"async_generate_text",
"GenDataTask",
"GenDataTaskResult",
"async_generate_data",
"async_setup",
"async_setup_entry",
"async_unload_entry",
@ -57,8 +57,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async_setup_http(hass)
hass.services.async_register(
DOMAIN,
SERVICE_GENERATE_TEXT,
async_service_generate_text,
SERVICE_GENERATE_DATA,
async_service_generate_data,
schema=vol.Schema(
{
vol.Required(ATTR_TASK_NAME): cv.string,
@ -82,18 +82,18 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
async def async_service_generate_text(call: ServiceCall) -> ServiceResponse:
async def async_service_generate_data(call: ServiceCall) -> ServiceResponse:
"""Run the run task service."""
result = await async_generate_text(hass=call.hass, **call.data)
return result.as_dict() # type: ignore[return-value]
result = await async_generate_data(hass=call.hass, **call.data)
return result.as_dict()
class AITaskPreferences:
"""AI Task preferences."""
KEYS = ("gen_text_entity_id",)
KEYS = ("gen_data_entity_id",)
gen_text_entity_id: str | None = None
gen_data_entity_id: str | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the preferences."""
@ -113,11 +113,11 @@ class AITaskPreferences:
def async_set_preferences(
self,
*,
gen_text_entity_id: str | None | UndefinedType = UNDEFINED,
gen_data_entity_id: str | None | UndefinedType = UNDEFINED,
) -> None:
"""Set the preferences."""
changed = False
for key, value in (("gen_text_entity_id", gen_text_entity_id),):
for key, value in (("gen_data_entity_id", gen_data_entity_id),):
if value is not UNDEFINED:
if getattr(self, key) != value:
setattr(self, key, value)

View File

@ -17,7 +17,7 @@ DOMAIN = "ai_task"
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
SERVICE_GENERATE_TEXT = "generate_text"
SERVICE_GENERATE_DATA = "generate_data"
ATTR_INSTRUCTIONS: Final = "instructions"
ATTR_TASK_NAME: Final = "task_name"
@ -30,5 +30,5 @@ DEFAULT_SYSTEM_PROMPT = (
class AITaskEntityFeature(IntFlag):
"""Supported features of the AI task entity."""
GENERATE_TEXT = 1
"""Generate text based on instructions."""
GENERATE_DATA = 1
"""Generate data based on instructions."""

View File

@ -18,7 +18,7 @@ from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util
from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN, AITaskEntityFeature
from .task import GenTextTask, GenTextTaskResult
from .task import GenDataTask, GenDataTaskResult
class AITaskEntity(RestoreEntity):
@ -56,7 +56,7 @@ class AITaskEntity(RestoreEntity):
@contextlib.asynccontextmanager
async def _async_get_ai_task_chat_log(
self,
task: GenTextTask,
task: GenDataTask,
) -> AsyncGenerator[ChatLog]:
"""Context manager used to manage the ChatLog used during an AI Task."""
# pylint: disable-next=contextmanager-generator-missing-cleanup
@ -84,20 +84,20 @@ class AITaskEntity(RestoreEntity):
yield chat_log
@final
async def internal_async_generate_text(
async def internal_async_generate_data(
self,
task: GenTextTask,
) -> GenTextTaskResult:
"""Run a gen text task."""
task: GenDataTask,
) -> GenDataTaskResult:
"""Run a gen data task."""
self.__last_activity = dt_util.utcnow().isoformat()
self.async_write_ha_state()
async with self._async_get_ai_task_chat_log(task) as chat_log:
return await self._async_generate_text(task, chat_log)
return await self._async_generate_data(task, chat_log)
async def _async_generate_text(
async def _async_generate_data(
self,
task: GenTextTask,
task: GenDataTask,
chat_log: ChatLog,
) -> GenTextTaskResult:
"""Handle a gen text task."""
) -> GenDataTaskResult:
"""Handle a gen data task."""
raise NotImplementedError

View File

@ -36,7 +36,7 @@ def websocket_get_preferences(
@websocket_api.websocket_command(
{
vol.Required("type"): "ai_task/preferences/set",
vol.Optional("gen_text_entity_id"): vol.Any(str, None),
vol.Optional("gen_data_entity_id"): vol.Any(str, None),
}
)
@websocket_api.require_admin

View File

@ -1,6 +1,6 @@
{
"services": {
"generate_text": {
"generate_data": {
"service": "mdi:file-star-four-points-outline"
}
}

View File

@ -1,4 +1,4 @@
generate_text:
generate_data:
fields:
task_name:
example: "home summary"
@ -16,4 +16,4 @@ generate_text:
entity:
domain: ai_task
supported_features:
- ai_task.AITaskEntityFeature.GENERATE_TEXT
- ai_task.AITaskEntityFeature.GENERATE_DATA

View File

@ -1,8 +1,8 @@
{
"services": {
"generate_text": {
"name": "Generate text",
"description": "Use AI to run a task that generates text.",
"generate_data": {
"name": "Generate data",
"description": "Uses AI to run a task that generates data.",
"fields": {
"task_name": {
"name": "Task name",

View File

@ -3,6 +3,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@ -10,16 +11,16 @@ from homeassistant.exceptions import HomeAssistantError
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
async def async_generate_text(
async def async_generate_data(
hass: HomeAssistant,
*,
task_name: str,
entity_id: str | None = None,
instructions: str,
) -> GenTextTaskResult:
) -> GenDataTaskResult:
"""Run a task in the AI Task integration."""
if entity_id is None:
entity_id = hass.data[DATA_PREFERENCES].gen_text_entity_id
entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id
if entity_id is None:
raise HomeAssistantError("No entity_id provided and no preferred entity set")
@ -28,13 +29,13 @@ async def async_generate_text(
if entity is None:
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
if AITaskEntityFeature.GENERATE_TEXT not in entity.supported_features:
if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features:
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support generating text"
f"AI Task entity {entity_id} does not support generating data"
)
return await entity.internal_async_generate_text(
GenTextTask(
return await entity.internal_async_generate_data(
GenDataTask(
name=task_name,
instructions=instructions,
)
@ -42,8 +43,8 @@ async def async_generate_text(
@dataclass(slots=True)
class GenTextTask:
"""Gen text task to be processed."""
class GenDataTask:
"""Gen data task to be processed."""
name: str
"""Name of the task."""
@ -53,22 +54,22 @@ class GenTextTask:
def __str__(self) -> str:
"""Return task as a string."""
return f"<GenTextTask {self.name}: {id(self)}>"
return f"<GenDataTask {self.name}: {id(self)}>"
@dataclass(slots=True)
class GenTextTaskResult:
"""Result of gen text task."""
class GenDataTaskResult:
"""Result of gen data task."""
conversation_id: str
"""Unique identifier for the conversation."""
text: str
"""Generated text."""
data: Any
"""Data generated by the task."""
def as_dict(self) -> dict[str, str]:
def as_dict(self) -> dict[str, Any]:
"""Return result as a dict."""
return {
"conversation_id": self.conversation_id,
"text": self.text,
"data": self.data,
}

View File

@ -6,8 +6,8 @@ from homeassistant.components.ai_task import (
DOMAIN,
AITaskEntity,
AITaskEntityFeature,
GenTextTask,
GenTextTaskResult,
GenDataTask,
GenDataTaskResult,
)
from homeassistant.components.conversation import AssistantContent, ChatLog
from homeassistant.config_entries import ConfigEntry, ConfigFlow
@ -33,24 +33,24 @@ class MockAITaskEntity(AITaskEntity):
"""Mock AI Task entity for testing."""
_attr_name = "Test Task Entity"
_attr_supported_features = AITaskEntityFeature.GENERATE_TEXT
_attr_supported_features = AITaskEntityFeature.GENERATE_DATA
def __init__(self) -> None:
"""Initialize the mock entity."""
super().__init__()
self.mock_generate_text_tasks = []
self.mock_generate_data_tasks = []
async def _async_generate_text(
self, task: GenTextTask, chat_log: ChatLog
) -> GenTextTaskResult:
"""Mock handling of generate text task."""
self.mock_generate_text_tasks.append(task)
async def _async_generate_data(
self, task: GenDataTask, chat_log: ChatLog
) -> GenDataTaskResult:
"""Mock handling of generate data task."""
self.mock_generate_data_tasks.append(task)
chat_log.async_add_assistant_content_without_tools(
AssistantContent(self.entity_id, "Mock result")
)
return GenTextTaskResult(
return GenDataTaskResult(
conversation_id=chat_log.conversation_id,
text="Mock result",
data="Mock result",
)

View File

@ -1,5 +1,5 @@
# serializer version: 1
# name: test_run_text_task_updates_chat_log
# name: test_run_data_task_updates_chat_log
list([
dict({
'content': '''

View File

@ -2,7 +2,7 @@
from freezegun import freeze_time
from homeassistant.components.ai_task import async_generate_text
from homeassistant.components.ai_task import async_generate_data
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant
@ -12,28 +12,28 @@ from tests.common import MockConfigEntry
@freeze_time("2025-06-08 16:28:13")
async def test_state_generate_text(
async def test_state_generate_data(
hass: HomeAssistant,
init_components: None,
mock_config_entry: MockConfigEntry,
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test the state of the AI Task entity is updated when generating text."""
"""Test the state of the AI Task entity is updated when generating data."""
entity = hass.states.get(TEST_ENTITY_ID)
assert entity is not None
assert entity.state == STATE_UNKNOWN
result = await async_generate_text(
result = await async_generate_data(
hass,
task_name="Test task",
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
)
assert result.text == "Mock result"
assert result.data == "Mock result"
entity = hass.states.get(TEST_ENTITY_ID)
assert entity.state == "2025-06-08T16:28:13+00:00"
assert mock_ai_task_entity.mock_generate_text_tasks
task = mock_ai_task_entity.mock_generate_text_tasks[0]
assert mock_ai_task_entity.mock_generate_data_tasks
task = mock_ai_task_entity.mock_generate_data_tasks[0]
assert task.instructions == "Test prompt"

View File

@ -18,20 +18,20 @@ async def test_ws_preferences(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": None,
"gen_data_entity_id": None,
}
# Set preferences
await client.send_json_auto_id(
{
"type": "ai_task/preferences/set",
"gen_text_entity_id": "ai_task.summary_1",
"gen_data_entity_id": "ai_task.summary_1",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_1",
"gen_data_entity_id": "ai_task.summary_1",
}
# Get updated preferences
@ -39,20 +39,20 @@ async def test_ws_preferences(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_1",
"gen_data_entity_id": "ai_task.summary_1",
}
# Update an existing preference
await client.send_json_auto_id(
{
"type": "ai_task/preferences/set",
"gen_text_entity_id": "ai_task.summary_2",
"gen_data_entity_id": "ai_task.summary_2",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_2",
"gen_data_entity_id": "ai_task.summary_2",
}
# Get updated preferences
@ -60,7 +60,7 @@ async def test_ws_preferences(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_2",
"gen_data_entity_id": "ai_task.summary_2",
}
# No preferences set will preserve existing preferences
@ -72,7 +72,7 @@ async def test_ws_preferences(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_2",
"gen_data_entity_id": "ai_task.summary_2",
}
# Get updated preferences
@ -80,5 +80,5 @@ async def test_ws_preferences(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_2",
"gen_data_entity_id": "ai_task.summary_2",
}

View File

@ -49,7 +49,7 @@ async def test_preferences_storage_load(
("set_preferences", "msg_extra"),
[
(
{"gen_text_entity_id": TEST_ENTITY_ID},
{"gen_data_entity_id": TEST_ENTITY_ID},
{},
),
(
@ -58,20 +58,20 @@ async def test_preferences_storage_load(
),
],
)
async def test_generate_text_service(
async def test_generate_data_service(
hass: HomeAssistant,
init_components: None,
freezer: FrozenDateTimeFactory,
set_preferences: dict[str, str | None],
msg_extra: dict[str, str],
) -> None:
"""Test the generate text service."""
"""Test the generate data service."""
preferences = hass.data[DATA_PREFERENCES]
preferences.async_set_preferences(**set_preferences)
result = await hass.services.async_call(
"ai_task",
"generate_text",
"generate_data",
{
"task_name": "Test Name",
"instructions": "Test prompt",
@ -81,4 +81,4 @@ async def test_generate_text_service(
return_response=True,
)
assert result["text"] == "Mock result"
assert result["data"] == "Mock result"

View File

@ -4,7 +4,7 @@ from freezegun import freeze_time
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_text
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data
from homeassistant.components.conversation import async_get_chat_log
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant
@ -28,7 +28,7 @@ async def test_run_task_preferred_entity(
with pytest.raises(
HomeAssistantError, match="No entity_id provided and no preferred entity set"
):
await async_generate_text(
await async_generate_data(
hass,
task_name="Test Task",
instructions="Test prompt",
@ -37,7 +37,7 @@ async def test_run_task_preferred_entity(
await client.send_json_auto_id(
{
"type": "ai_task/preferences/set",
"gen_text_entity_id": "ai_task.unknown",
"gen_data_entity_id": "ai_task.unknown",
}
)
msg = await client.receive_json()
@ -46,7 +46,7 @@ async def test_run_task_preferred_entity(
with pytest.raises(
HomeAssistantError, match="AI Task entity ai_task.unknown not found"
):
await async_generate_text(
await async_generate_data(
hass,
task_name="Test Task",
instructions="Test prompt",
@ -55,7 +55,7 @@ async def test_run_task_preferred_entity(
await client.send_json_auto_id(
{
"type": "ai_task/preferences/set",
"gen_text_entity_id": TEST_ENTITY_ID,
"gen_data_entity_id": TEST_ENTITY_ID,
}
)
msg = await client.receive_json()
@ -65,12 +65,15 @@ async def test_run_task_preferred_entity(
assert state is not None
assert state.state == STATE_UNKNOWN
result = await async_generate_text(
result = await async_generate_data(
hass,
task_name="Test Task",
instructions="Test prompt",
)
assert result.text == "Mock result"
assert result.data == "Mock result"
as_dict = result.as_dict()
assert as_dict["conversation_id"] == result.conversation_id
assert as_dict["data"] == "Mock result"
state = hass.states.get(TEST_ENTITY_ID)
assert state is not None
assert state.state != STATE_UNKNOWN
@ -78,25 +81,25 @@ async def test_run_task_preferred_entity(
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
with pytest.raises(
HomeAssistantError,
match="AI Task entity ai_task.test_task_entity does not support generating text",
match="AI Task entity ai_task.test_task_entity does not support generating data",
):
await async_generate_text(
await async_generate_data(
hass,
task_name="Test Task",
instructions="Test prompt",
)
async def test_run_text_task_unknown_entity(
async def test_run_data_task_unknown_entity(
hass: HomeAssistant,
init_components: None,
) -> None:
"""Test running a text task with an unknown entity."""
"""Test running a data task with an unknown entity."""
with pytest.raises(
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
):
await async_generate_text(
await async_generate_data(
hass,
task_name="Test Task",
entity_id="ai_task.unknown_entity",
@ -105,19 +108,19 @@ async def test_run_text_task_unknown_entity(
@freeze_time("2025-06-14 22:59:00")
async def test_run_text_task_updates_chat_log(
async def test_run_data_task_updates_chat_log(
hass: HomeAssistant,
init_components: None,
snapshot: SnapshotAssertion,
) -> None:
"""Test that running a text task updates the chat log."""
result = await async_generate_text(
"""Test that running a data task updates the chat log."""
result = await async_generate_data(
hass,
task_name="Test Task",
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
)
assert result.text == "Mock result"
assert result.data == "Mock result"
with (
chat_session.async_get_chat_session(hass, result.conversation_id) as session,