mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 17:27:10 +00:00
AI task generate_text -> generate_data (#147370)
This commit is contained in:
parent
38c7eaf70a
commit
63ac14a19b
@ -24,20 +24,20 @@ from .const import (
|
|||||||
DATA_COMPONENT,
|
DATA_COMPONENT,
|
||||||
DATA_PREFERENCES,
|
DATA_PREFERENCES,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
SERVICE_GENERATE_TEXT,
|
SERVICE_GENERATE_DATA,
|
||||||
AITaskEntityFeature,
|
AITaskEntityFeature,
|
||||||
)
|
)
|
||||||
from .entity import AITaskEntity
|
from .entity import AITaskEntity
|
||||||
from .http import async_setup as async_setup_http
|
from .http import async_setup as async_setup_http
|
||||||
from .task import GenTextTask, GenTextTaskResult, async_generate_text
|
from .task import GenDataTask, GenDataTaskResult, async_generate_data
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
"AITaskEntity",
|
"AITaskEntity",
|
||||||
"AITaskEntityFeature",
|
"AITaskEntityFeature",
|
||||||
"GenTextTask",
|
"GenDataTask",
|
||||||
"GenTextTaskResult",
|
"GenDataTaskResult",
|
||||||
"async_generate_text",
|
"async_generate_data",
|
||||||
"async_setup",
|
"async_setup",
|
||||||
"async_setup_entry",
|
"async_setup_entry",
|
||||||
"async_unload_entry",
|
"async_unload_entry",
|
||||||
@ -57,8 +57,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
async_setup_http(hass)
|
async_setup_http(hass)
|
||||||
hass.services.async_register(
|
hass.services.async_register(
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
SERVICE_GENERATE_TEXT,
|
SERVICE_GENERATE_DATA,
|
||||||
async_service_generate_text,
|
async_service_generate_data,
|
||||||
schema=vol.Schema(
|
schema=vol.Schema(
|
||||||
{
|
{
|
||||||
vol.Required(ATTR_TASK_NAME): cv.string,
|
vol.Required(ATTR_TASK_NAME): cv.string,
|
||||||
@ -82,18 +82,18 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
|
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
|
||||||
|
|
||||||
|
|
||||||
async def async_service_generate_text(call: ServiceCall) -> ServiceResponse:
|
async def async_service_generate_data(call: ServiceCall) -> ServiceResponse:
|
||||||
"""Run the run task service."""
|
"""Run the run task service."""
|
||||||
result = await async_generate_text(hass=call.hass, **call.data)
|
result = await async_generate_data(hass=call.hass, **call.data)
|
||||||
return result.as_dict() # type: ignore[return-value]
|
return result.as_dict()
|
||||||
|
|
||||||
|
|
||||||
class AITaskPreferences:
|
class AITaskPreferences:
|
||||||
"""AI Task preferences."""
|
"""AI Task preferences."""
|
||||||
|
|
||||||
KEYS = ("gen_text_entity_id",)
|
KEYS = ("gen_data_entity_id",)
|
||||||
|
|
||||||
gen_text_entity_id: str | None = None
|
gen_data_entity_id: str | None = None
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the preferences."""
|
"""Initialize the preferences."""
|
||||||
@ -113,11 +113,11 @@ class AITaskPreferences:
|
|||||||
def async_set_preferences(
|
def async_set_preferences(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
gen_text_entity_id: str | None | UndefinedType = UNDEFINED,
|
gen_data_entity_id: str | None | UndefinedType = UNDEFINED,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set the preferences."""
|
"""Set the preferences."""
|
||||||
changed = False
|
changed = False
|
||||||
for key, value in (("gen_text_entity_id", gen_text_entity_id),):
|
for key, value in (("gen_data_entity_id", gen_data_entity_id),):
|
||||||
if value is not UNDEFINED:
|
if value is not UNDEFINED:
|
||||||
if getattr(self, key) != value:
|
if getattr(self, key) != value:
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
@ -17,7 +17,7 @@ DOMAIN = "ai_task"
|
|||||||
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
|
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
|
||||||
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
|
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
|
||||||
|
|
||||||
SERVICE_GENERATE_TEXT = "generate_text"
|
SERVICE_GENERATE_DATA = "generate_data"
|
||||||
|
|
||||||
ATTR_INSTRUCTIONS: Final = "instructions"
|
ATTR_INSTRUCTIONS: Final = "instructions"
|
||||||
ATTR_TASK_NAME: Final = "task_name"
|
ATTR_TASK_NAME: Final = "task_name"
|
||||||
@ -30,5 +30,5 @@ DEFAULT_SYSTEM_PROMPT = (
|
|||||||
class AITaskEntityFeature(IntFlag):
|
class AITaskEntityFeature(IntFlag):
|
||||||
"""Supported features of the AI task entity."""
|
"""Supported features of the AI task entity."""
|
||||||
|
|
||||||
GENERATE_TEXT = 1
|
GENERATE_DATA = 1
|
||||||
"""Generate text based on instructions."""
|
"""Generate data based on instructions."""
|
||||||
|
@ -18,7 +18,7 @@ from homeassistant.helpers.restore_state import RestoreEntity
|
|||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN, AITaskEntityFeature
|
from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN, AITaskEntityFeature
|
||||||
from .task import GenTextTask, GenTextTaskResult
|
from .task import GenDataTask, GenDataTaskResult
|
||||||
|
|
||||||
|
|
||||||
class AITaskEntity(RestoreEntity):
|
class AITaskEntity(RestoreEntity):
|
||||||
@ -56,7 +56,7 @@ class AITaskEntity(RestoreEntity):
|
|||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def _async_get_ai_task_chat_log(
|
async def _async_get_ai_task_chat_log(
|
||||||
self,
|
self,
|
||||||
task: GenTextTask,
|
task: GenDataTask,
|
||||||
) -> AsyncGenerator[ChatLog]:
|
) -> AsyncGenerator[ChatLog]:
|
||||||
"""Context manager used to manage the ChatLog used during an AI Task."""
|
"""Context manager used to manage the ChatLog used during an AI Task."""
|
||||||
# pylint: disable-next=contextmanager-generator-missing-cleanup
|
# pylint: disable-next=contextmanager-generator-missing-cleanup
|
||||||
@ -84,20 +84,20 @@ class AITaskEntity(RestoreEntity):
|
|||||||
yield chat_log
|
yield chat_log
|
||||||
|
|
||||||
@final
|
@final
|
||||||
async def internal_async_generate_text(
|
async def internal_async_generate_data(
|
||||||
self,
|
self,
|
||||||
task: GenTextTask,
|
task: GenDataTask,
|
||||||
) -> GenTextTaskResult:
|
) -> GenDataTaskResult:
|
||||||
"""Run a gen text task."""
|
"""Run a gen data task."""
|
||||||
self.__last_activity = dt_util.utcnow().isoformat()
|
self.__last_activity = dt_util.utcnow().isoformat()
|
||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
async with self._async_get_ai_task_chat_log(task) as chat_log:
|
async with self._async_get_ai_task_chat_log(task) as chat_log:
|
||||||
return await self._async_generate_text(task, chat_log)
|
return await self._async_generate_data(task, chat_log)
|
||||||
|
|
||||||
async def _async_generate_text(
|
async def _async_generate_data(
|
||||||
self,
|
self,
|
||||||
task: GenTextTask,
|
task: GenDataTask,
|
||||||
chat_log: ChatLog,
|
chat_log: ChatLog,
|
||||||
) -> GenTextTaskResult:
|
) -> GenDataTaskResult:
|
||||||
"""Handle a gen text task."""
|
"""Handle a gen data task."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -36,7 +36,7 @@ def websocket_get_preferences(
|
|||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
{
|
{
|
||||||
vol.Required("type"): "ai_task/preferences/set",
|
vol.Required("type"): "ai_task/preferences/set",
|
||||||
vol.Optional("gen_text_entity_id"): vol.Any(str, None),
|
vol.Optional("gen_data_entity_id"): vol.Any(str, None),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@websocket_api.require_admin
|
@websocket_api.require_admin
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"services": {
|
"services": {
|
||||||
"generate_text": {
|
"generate_data": {
|
||||||
"service": "mdi:file-star-four-points-outline"
|
"service": "mdi:file-star-four-points-outline"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
generate_text:
|
generate_data:
|
||||||
fields:
|
fields:
|
||||||
task_name:
|
task_name:
|
||||||
example: "home summary"
|
example: "home summary"
|
||||||
@ -16,4 +16,4 @@ generate_text:
|
|||||||
entity:
|
entity:
|
||||||
domain: ai_task
|
domain: ai_task
|
||||||
supported_features:
|
supported_features:
|
||||||
- ai_task.AITaskEntityFeature.GENERATE_TEXT
|
- ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
{
|
{
|
||||||
"services": {
|
"services": {
|
||||||
"generate_text": {
|
"generate_data": {
|
||||||
"name": "Generate text",
|
"name": "Generate data",
|
||||||
"description": "Use AI to run a task that generates text.",
|
"description": "Uses AI to run a task that generates data.",
|
||||||
"fields": {
|
"fields": {
|
||||||
"task_name": {
|
"task_name": {
|
||||||
"name": "Task name",
|
"name": "Task name",
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
@ -10,16 +11,16 @@ from homeassistant.exceptions import HomeAssistantError
|
|||||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||||
|
|
||||||
|
|
||||||
async def async_generate_text(
|
async def async_generate_data(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
*,
|
*,
|
||||||
task_name: str,
|
task_name: str,
|
||||||
entity_id: str | None = None,
|
entity_id: str | None = None,
|
||||||
instructions: str,
|
instructions: str,
|
||||||
) -> GenTextTaskResult:
|
) -> GenDataTaskResult:
|
||||||
"""Run a task in the AI Task integration."""
|
"""Run a task in the AI Task integration."""
|
||||||
if entity_id is None:
|
if entity_id is None:
|
||||||
entity_id = hass.data[DATA_PREFERENCES].gen_text_entity_id
|
entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id
|
||||||
|
|
||||||
if entity_id is None:
|
if entity_id is None:
|
||||||
raise HomeAssistantError("No entity_id provided and no preferred entity set")
|
raise HomeAssistantError("No entity_id provided and no preferred entity set")
|
||||||
@ -28,13 +29,13 @@ async def async_generate_text(
|
|||||||
if entity is None:
|
if entity is None:
|
||||||
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
|
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
|
||||||
|
|
||||||
if AITaskEntityFeature.GENERATE_TEXT not in entity.supported_features:
|
if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features:
|
||||||
raise HomeAssistantError(
|
raise HomeAssistantError(
|
||||||
f"AI Task entity {entity_id} does not support generating text"
|
f"AI Task entity {entity_id} does not support generating data"
|
||||||
)
|
)
|
||||||
|
|
||||||
return await entity.internal_async_generate_text(
|
return await entity.internal_async_generate_data(
|
||||||
GenTextTask(
|
GenDataTask(
|
||||||
name=task_name,
|
name=task_name,
|
||||||
instructions=instructions,
|
instructions=instructions,
|
||||||
)
|
)
|
||||||
@ -42,8 +43,8 @@ async def async_generate_text(
|
|||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class GenTextTask:
|
class GenDataTask:
|
||||||
"""Gen text task to be processed."""
|
"""Gen data task to be processed."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
"""Name of the task."""
|
"""Name of the task."""
|
||||||
@ -53,22 +54,22 @@ class GenTextTask:
|
|||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""Return task as a string."""
|
"""Return task as a string."""
|
||||||
return f"<GenTextTask {self.name}: {id(self)}>"
|
return f"<GenDataTask {self.name}: {id(self)}>"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class GenTextTaskResult:
|
class GenDataTaskResult:
|
||||||
"""Result of gen text task."""
|
"""Result of gen data task."""
|
||||||
|
|
||||||
conversation_id: str
|
conversation_id: str
|
||||||
"""Unique identifier for the conversation."""
|
"""Unique identifier for the conversation."""
|
||||||
|
|
||||||
text: str
|
data: Any
|
||||||
"""Generated text."""
|
"""Data generated by the task."""
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, str]:
|
def as_dict(self) -> dict[str, Any]:
|
||||||
"""Return result as a dict."""
|
"""Return result as a dict."""
|
||||||
return {
|
return {
|
||||||
"conversation_id": self.conversation_id,
|
"conversation_id": self.conversation_id,
|
||||||
"text": self.text,
|
"data": self.data,
|
||||||
}
|
}
|
||||||
|
@ -6,8 +6,8 @@ from homeassistant.components.ai_task import (
|
|||||||
DOMAIN,
|
DOMAIN,
|
||||||
AITaskEntity,
|
AITaskEntity,
|
||||||
AITaskEntityFeature,
|
AITaskEntityFeature,
|
||||||
GenTextTask,
|
GenDataTask,
|
||||||
GenTextTaskResult,
|
GenDataTaskResult,
|
||||||
)
|
)
|
||||||
from homeassistant.components.conversation import AssistantContent, ChatLog
|
from homeassistant.components.conversation import AssistantContent, ChatLog
|
||||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||||
@ -33,24 +33,24 @@ class MockAITaskEntity(AITaskEntity):
|
|||||||
"""Mock AI Task entity for testing."""
|
"""Mock AI Task entity for testing."""
|
||||||
|
|
||||||
_attr_name = "Test Task Entity"
|
_attr_name = "Test Task Entity"
|
||||||
_attr_supported_features = AITaskEntityFeature.GENERATE_TEXT
|
_attr_supported_features = AITaskEntityFeature.GENERATE_DATA
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize the mock entity."""
|
"""Initialize the mock entity."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.mock_generate_text_tasks = []
|
self.mock_generate_data_tasks = []
|
||||||
|
|
||||||
async def _async_generate_text(
|
async def _async_generate_data(
|
||||||
self, task: GenTextTask, chat_log: ChatLog
|
self, task: GenDataTask, chat_log: ChatLog
|
||||||
) -> GenTextTaskResult:
|
) -> GenDataTaskResult:
|
||||||
"""Mock handling of generate text task."""
|
"""Mock handling of generate data task."""
|
||||||
self.mock_generate_text_tasks.append(task)
|
self.mock_generate_data_tasks.append(task)
|
||||||
chat_log.async_add_assistant_content_without_tools(
|
chat_log.async_add_assistant_content_without_tools(
|
||||||
AssistantContent(self.entity_id, "Mock result")
|
AssistantContent(self.entity_id, "Mock result")
|
||||||
)
|
)
|
||||||
return GenTextTaskResult(
|
return GenDataTaskResult(
|
||||||
conversation_id=chat_log.conversation_id,
|
conversation_id=chat_log.conversation_id,
|
||||||
text="Mock result",
|
data="Mock result",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_run_text_task_updates_chat_log
|
# name: test_run_data_task_updates_chat_log
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'content': '''
|
'content': '''
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
|
||||||
from homeassistant.components.ai_task import async_generate_text
|
from homeassistant.components.ai_task import async_generate_data
|
||||||
from homeassistant.const import STATE_UNKNOWN
|
from homeassistant.const import STATE_UNKNOWN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
@ -12,28 +12,28 @@ from tests.common import MockConfigEntry
|
|||||||
|
|
||||||
|
|
||||||
@freeze_time("2025-06-08 16:28:13")
|
@freeze_time("2025-06-08 16:28:13")
|
||||||
async def test_state_generate_text(
|
async def test_state_generate_data(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
init_components: None,
|
init_components: None,
|
||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
mock_ai_task_entity: MockAITaskEntity,
|
mock_ai_task_entity: MockAITaskEntity,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the state of the AI Task entity is updated when generating text."""
|
"""Test the state of the AI Task entity is updated when generating data."""
|
||||||
entity = hass.states.get(TEST_ENTITY_ID)
|
entity = hass.states.get(TEST_ENTITY_ID)
|
||||||
assert entity is not None
|
assert entity is not None
|
||||||
assert entity.state == STATE_UNKNOWN
|
assert entity.state == STATE_UNKNOWN
|
||||||
|
|
||||||
result = await async_generate_text(
|
result = await async_generate_data(
|
||||||
hass,
|
hass,
|
||||||
task_name="Test task",
|
task_name="Test task",
|
||||||
entity_id=TEST_ENTITY_ID,
|
entity_id=TEST_ENTITY_ID,
|
||||||
instructions="Test prompt",
|
instructions="Test prompt",
|
||||||
)
|
)
|
||||||
assert result.text == "Mock result"
|
assert result.data == "Mock result"
|
||||||
|
|
||||||
entity = hass.states.get(TEST_ENTITY_ID)
|
entity = hass.states.get(TEST_ENTITY_ID)
|
||||||
assert entity.state == "2025-06-08T16:28:13+00:00"
|
assert entity.state == "2025-06-08T16:28:13+00:00"
|
||||||
|
|
||||||
assert mock_ai_task_entity.mock_generate_text_tasks
|
assert mock_ai_task_entity.mock_generate_data_tasks
|
||||||
task = mock_ai_task_entity.mock_generate_text_tasks[0]
|
task = mock_ai_task_entity.mock_generate_data_tasks[0]
|
||||||
assert task.instructions == "Test prompt"
|
assert task.instructions == "Test prompt"
|
||||||
|
@ -18,20 +18,20 @@ async def test_ws_preferences(
|
|||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {
|
assert msg["result"] == {
|
||||||
"gen_text_entity_id": None,
|
"gen_data_entity_id": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Set preferences
|
# Set preferences
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
"type": "ai_task/preferences/set",
|
"type": "ai_task/preferences/set",
|
||||||
"gen_text_entity_id": "ai_task.summary_1",
|
"gen_data_entity_id": "ai_task.summary_1",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {
|
assert msg["result"] == {
|
||||||
"gen_text_entity_id": "ai_task.summary_1",
|
"gen_data_entity_id": "ai_task.summary_1",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get updated preferences
|
# Get updated preferences
|
||||||
@ -39,20 +39,20 @@ async def test_ws_preferences(
|
|||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {
|
assert msg["result"] == {
|
||||||
"gen_text_entity_id": "ai_task.summary_1",
|
"gen_data_entity_id": "ai_task.summary_1",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Update an existing preference
|
# Update an existing preference
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
"type": "ai_task/preferences/set",
|
"type": "ai_task/preferences/set",
|
||||||
"gen_text_entity_id": "ai_task.summary_2",
|
"gen_data_entity_id": "ai_task.summary_2",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {
|
assert msg["result"] == {
|
||||||
"gen_text_entity_id": "ai_task.summary_2",
|
"gen_data_entity_id": "ai_task.summary_2",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get updated preferences
|
# Get updated preferences
|
||||||
@ -60,7 +60,7 @@ async def test_ws_preferences(
|
|||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {
|
assert msg["result"] == {
|
||||||
"gen_text_entity_id": "ai_task.summary_2",
|
"gen_data_entity_id": "ai_task.summary_2",
|
||||||
}
|
}
|
||||||
|
|
||||||
# No preferences set will preserve existing preferences
|
# No preferences set will preserve existing preferences
|
||||||
@ -72,7 +72,7 @@ async def test_ws_preferences(
|
|||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {
|
assert msg["result"] == {
|
||||||
"gen_text_entity_id": "ai_task.summary_2",
|
"gen_data_entity_id": "ai_task.summary_2",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get updated preferences
|
# Get updated preferences
|
||||||
@ -80,5 +80,5 @@ async def test_ws_preferences(
|
|||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {
|
assert msg["result"] == {
|
||||||
"gen_text_entity_id": "ai_task.summary_2",
|
"gen_data_entity_id": "ai_task.summary_2",
|
||||||
}
|
}
|
||||||
|
@ -49,7 +49,7 @@ async def test_preferences_storage_load(
|
|||||||
("set_preferences", "msg_extra"),
|
("set_preferences", "msg_extra"),
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
{"gen_text_entity_id": TEST_ENTITY_ID},
|
{"gen_data_entity_id": TEST_ENTITY_ID},
|
||||||
{},
|
{},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
@ -58,20 +58,20 @@ async def test_preferences_storage_load(
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_generate_text_service(
|
async def test_generate_data_service(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
init_components: None,
|
init_components: None,
|
||||||
freezer: FrozenDateTimeFactory,
|
freezer: FrozenDateTimeFactory,
|
||||||
set_preferences: dict[str, str | None],
|
set_preferences: dict[str, str | None],
|
||||||
msg_extra: dict[str, str],
|
msg_extra: dict[str, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the generate text service."""
|
"""Test the generate data service."""
|
||||||
preferences = hass.data[DATA_PREFERENCES]
|
preferences = hass.data[DATA_PREFERENCES]
|
||||||
preferences.async_set_preferences(**set_preferences)
|
preferences.async_set_preferences(**set_preferences)
|
||||||
|
|
||||||
result = await hass.services.async_call(
|
result = await hass.services.async_call(
|
||||||
"ai_task",
|
"ai_task",
|
||||||
"generate_text",
|
"generate_data",
|
||||||
{
|
{
|
||||||
"task_name": "Test Name",
|
"task_name": "Test Name",
|
||||||
"instructions": "Test prompt",
|
"instructions": "Test prompt",
|
||||||
@ -81,4 +81,4 @@ async def test_generate_text_service(
|
|||||||
return_response=True,
|
return_response=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["text"] == "Mock result"
|
assert result["data"] == "Mock result"
|
||||||
|
@ -4,7 +4,7 @@ from freezegun import freeze_time
|
|||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_text
|
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data
|
||||||
from homeassistant.components.conversation import async_get_chat_log
|
from homeassistant.components.conversation import async_get_chat_log
|
||||||
from homeassistant.const import STATE_UNKNOWN
|
from homeassistant.const import STATE_UNKNOWN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -28,7 +28,7 @@ async def test_run_task_preferred_entity(
|
|||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
HomeAssistantError, match="No entity_id provided and no preferred entity set"
|
HomeAssistantError, match="No entity_id provided and no preferred entity set"
|
||||||
):
|
):
|
||||||
await async_generate_text(
|
await async_generate_data(
|
||||||
hass,
|
hass,
|
||||||
task_name="Test Task",
|
task_name="Test Task",
|
||||||
instructions="Test prompt",
|
instructions="Test prompt",
|
||||||
@ -37,7 +37,7 @@ async def test_run_task_preferred_entity(
|
|||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
"type": "ai_task/preferences/set",
|
"type": "ai_task/preferences/set",
|
||||||
"gen_text_entity_id": "ai_task.unknown",
|
"gen_data_entity_id": "ai_task.unknown",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
@ -46,7 +46,7 @@ async def test_run_task_preferred_entity(
|
|||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
HomeAssistantError, match="AI Task entity ai_task.unknown not found"
|
HomeAssistantError, match="AI Task entity ai_task.unknown not found"
|
||||||
):
|
):
|
||||||
await async_generate_text(
|
await async_generate_data(
|
||||||
hass,
|
hass,
|
||||||
task_name="Test Task",
|
task_name="Test Task",
|
||||||
instructions="Test prompt",
|
instructions="Test prompt",
|
||||||
@ -55,7 +55,7 @@ async def test_run_task_preferred_entity(
|
|||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
"type": "ai_task/preferences/set",
|
"type": "ai_task/preferences/set",
|
||||||
"gen_text_entity_id": TEST_ENTITY_ID,
|
"gen_data_entity_id": TEST_ENTITY_ID,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
@ -65,12 +65,15 @@ async def test_run_task_preferred_entity(
|
|||||||
assert state is not None
|
assert state is not None
|
||||||
assert state.state == STATE_UNKNOWN
|
assert state.state == STATE_UNKNOWN
|
||||||
|
|
||||||
result = await async_generate_text(
|
result = await async_generate_data(
|
||||||
hass,
|
hass,
|
||||||
task_name="Test Task",
|
task_name="Test Task",
|
||||||
instructions="Test prompt",
|
instructions="Test prompt",
|
||||||
)
|
)
|
||||||
assert result.text == "Mock result"
|
assert result.data == "Mock result"
|
||||||
|
as_dict = result.as_dict()
|
||||||
|
assert as_dict["conversation_id"] == result.conversation_id
|
||||||
|
assert as_dict["data"] == "Mock result"
|
||||||
state = hass.states.get(TEST_ENTITY_ID)
|
state = hass.states.get(TEST_ENTITY_ID)
|
||||||
assert state is not None
|
assert state is not None
|
||||||
assert state.state != STATE_UNKNOWN
|
assert state.state != STATE_UNKNOWN
|
||||||
@ -78,25 +81,25 @@ async def test_run_task_preferred_entity(
|
|||||||
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
|
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
HomeAssistantError,
|
HomeAssistantError,
|
||||||
match="AI Task entity ai_task.test_task_entity does not support generating text",
|
match="AI Task entity ai_task.test_task_entity does not support generating data",
|
||||||
):
|
):
|
||||||
await async_generate_text(
|
await async_generate_data(
|
||||||
hass,
|
hass,
|
||||||
task_name="Test Task",
|
task_name="Test Task",
|
||||||
instructions="Test prompt",
|
instructions="Test prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_run_text_task_unknown_entity(
|
async def test_run_data_task_unknown_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
init_components: None,
|
init_components: None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test running a text task with an unknown entity."""
|
"""Test running a data task with an unknown entity."""
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
|
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
|
||||||
):
|
):
|
||||||
await async_generate_text(
|
await async_generate_data(
|
||||||
hass,
|
hass,
|
||||||
task_name="Test Task",
|
task_name="Test Task",
|
||||||
entity_id="ai_task.unknown_entity",
|
entity_id="ai_task.unknown_entity",
|
||||||
@ -105,19 +108,19 @@ async def test_run_text_task_unknown_entity(
|
|||||||
|
|
||||||
|
|
||||||
@freeze_time("2025-06-14 22:59:00")
|
@freeze_time("2025-06-14 22:59:00")
|
||||||
async def test_run_text_task_updates_chat_log(
|
async def test_run_data_task_updates_chat_log(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
init_components: None,
|
init_components: None,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that running a text task updates the chat log."""
|
"""Test that running a data task updates the chat log."""
|
||||||
result = await async_generate_text(
|
result = await async_generate_data(
|
||||||
hass,
|
hass,
|
||||||
task_name="Test Task",
|
task_name="Test Task",
|
||||||
entity_id=TEST_ENTITY_ID,
|
entity_id=TEST_ENTITY_ID,
|
||||||
instructions="Test prompt",
|
instructions="Test prompt",
|
||||||
)
|
)
|
||||||
assert result.text == "Mock result"
|
assert result.data == "Mock result"
|
||||||
|
|
||||||
with (
|
with (
|
||||||
chat_session.async_get_chat_session(hass, result.conversation_id) as session,
|
chat_session.async_get_chat_session(hass, result.conversation_id) as session,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user