AI task generate_text -> generate_data (#147370)

This commit is contained in:
Paulus Schoutsen 2025-06-24 07:12:29 -04:00 committed by GitHub
parent 38c7eaf70a
commit 63ac14a19b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 104 additions and 100 deletions

View File

@ -24,20 +24,20 @@ from .const import (
DATA_COMPONENT, DATA_COMPONENT,
DATA_PREFERENCES, DATA_PREFERENCES,
DOMAIN, DOMAIN,
SERVICE_GENERATE_TEXT, SERVICE_GENERATE_DATA,
AITaskEntityFeature, AITaskEntityFeature,
) )
from .entity import AITaskEntity from .entity import AITaskEntity
from .http import async_setup as async_setup_http from .http import async_setup as async_setup_http
from .task import GenTextTask, GenTextTaskResult, async_generate_text from .task import GenDataTask, GenDataTaskResult, async_generate_data
__all__ = [ __all__ = [
"DOMAIN", "DOMAIN",
"AITaskEntity", "AITaskEntity",
"AITaskEntityFeature", "AITaskEntityFeature",
"GenTextTask", "GenDataTask",
"GenTextTaskResult", "GenDataTaskResult",
"async_generate_text", "async_generate_data",
"async_setup", "async_setup",
"async_setup_entry", "async_setup_entry",
"async_unload_entry", "async_unload_entry",
@ -57,8 +57,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async_setup_http(hass) async_setup_http(hass)
hass.services.async_register( hass.services.async_register(
DOMAIN, DOMAIN,
SERVICE_GENERATE_TEXT, SERVICE_GENERATE_DATA,
async_service_generate_text, async_service_generate_data,
schema=vol.Schema( schema=vol.Schema(
{ {
vol.Required(ATTR_TASK_NAME): cv.string, vol.Required(ATTR_TASK_NAME): cv.string,
@ -82,18 +82,18 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return await hass.data[DATA_COMPONENT].async_unload_entry(entry) return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
async def async_service_generate_text(call: ServiceCall) -> ServiceResponse: async def async_service_generate_data(call: ServiceCall) -> ServiceResponse:
"""Run the run task service.""" """Run the run task service."""
result = await async_generate_text(hass=call.hass, **call.data) result = await async_generate_data(hass=call.hass, **call.data)
return result.as_dict() # type: ignore[return-value] return result.as_dict()
class AITaskPreferences: class AITaskPreferences:
"""AI Task preferences.""" """AI Task preferences."""
KEYS = ("gen_text_entity_id",) KEYS = ("gen_data_entity_id",)
gen_text_entity_id: str | None = None gen_data_entity_id: str | None = None
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the preferences.""" """Initialize the preferences."""
@ -113,11 +113,11 @@ class AITaskPreferences:
def async_set_preferences( def async_set_preferences(
self, self,
*, *,
gen_text_entity_id: str | None | UndefinedType = UNDEFINED, gen_data_entity_id: str | None | UndefinedType = UNDEFINED,
) -> None: ) -> None:
"""Set the preferences.""" """Set the preferences."""
changed = False changed = False
for key, value in (("gen_text_entity_id", gen_text_entity_id),): for key, value in (("gen_data_entity_id", gen_data_entity_id),):
if value is not UNDEFINED: if value is not UNDEFINED:
if getattr(self, key) != value: if getattr(self, key) != value:
setattr(self, key, value) setattr(self, key, value)

View File

@ -17,7 +17,7 @@ DOMAIN = "ai_task"
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN) DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences") DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
SERVICE_GENERATE_TEXT = "generate_text" SERVICE_GENERATE_DATA = "generate_data"
ATTR_INSTRUCTIONS: Final = "instructions" ATTR_INSTRUCTIONS: Final = "instructions"
ATTR_TASK_NAME: Final = "task_name" ATTR_TASK_NAME: Final = "task_name"
@ -30,5 +30,5 @@ DEFAULT_SYSTEM_PROMPT = (
class AITaskEntityFeature(IntFlag): class AITaskEntityFeature(IntFlag):
"""Supported features of the AI task entity.""" """Supported features of the AI task entity."""
GENERATE_TEXT = 1 GENERATE_DATA = 1
"""Generate text based on instructions.""" """Generate data based on instructions."""

View File

@ -18,7 +18,7 @@ from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN, AITaskEntityFeature from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN, AITaskEntityFeature
from .task import GenTextTask, GenTextTaskResult from .task import GenDataTask, GenDataTaskResult
class AITaskEntity(RestoreEntity): class AITaskEntity(RestoreEntity):
@ -56,7 +56,7 @@ class AITaskEntity(RestoreEntity):
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def _async_get_ai_task_chat_log( async def _async_get_ai_task_chat_log(
self, self,
task: GenTextTask, task: GenDataTask,
) -> AsyncGenerator[ChatLog]: ) -> AsyncGenerator[ChatLog]:
"""Context manager used to manage the ChatLog used during an AI Task.""" """Context manager used to manage the ChatLog used during an AI Task."""
# pylint: disable-next=contextmanager-generator-missing-cleanup # pylint: disable-next=contextmanager-generator-missing-cleanup
@ -84,20 +84,20 @@ class AITaskEntity(RestoreEntity):
yield chat_log yield chat_log
@final @final
async def internal_async_generate_text( async def internal_async_generate_data(
self, self,
task: GenTextTask, task: GenDataTask,
) -> GenTextTaskResult: ) -> GenDataTaskResult:
"""Run a gen text task.""" """Run a gen data task."""
self.__last_activity = dt_util.utcnow().isoformat() self.__last_activity = dt_util.utcnow().isoformat()
self.async_write_ha_state() self.async_write_ha_state()
async with self._async_get_ai_task_chat_log(task) as chat_log: async with self._async_get_ai_task_chat_log(task) as chat_log:
return await self._async_generate_text(task, chat_log) return await self._async_generate_data(task, chat_log)
async def _async_generate_text( async def _async_generate_data(
self, self,
task: GenTextTask, task: GenDataTask,
chat_log: ChatLog, chat_log: ChatLog,
) -> GenTextTaskResult: ) -> GenDataTaskResult:
"""Handle a gen text task.""" """Handle a gen data task."""
raise NotImplementedError raise NotImplementedError

View File

@ -36,7 +36,7 @@ def websocket_get_preferences(
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "ai_task/preferences/set", vol.Required("type"): "ai_task/preferences/set",
vol.Optional("gen_text_entity_id"): vol.Any(str, None), vol.Optional("gen_data_entity_id"): vol.Any(str, None),
} }
) )
@websocket_api.require_admin @websocket_api.require_admin

View File

@ -1,6 +1,6 @@
{ {
"services": { "services": {
"generate_text": { "generate_data": {
"service": "mdi:file-star-four-points-outline" "service": "mdi:file-star-four-points-outline"
} }
} }

View File

@ -1,4 +1,4 @@
generate_text: generate_data:
fields: fields:
task_name: task_name:
example: "home summary" example: "home summary"
@ -16,4 +16,4 @@ generate_text:
entity: entity:
domain: ai_task domain: ai_task
supported_features: supported_features:
- ai_task.AITaskEntityFeature.GENERATE_TEXT - ai_task.AITaskEntityFeature.GENERATE_DATA

View File

@ -1,8 +1,8 @@
{ {
"services": { "services": {
"generate_text": { "generate_data": {
"name": "Generate text", "name": "Generate data",
"description": "Use AI to run a task that generates text.", "description": "Uses AI to run a task that generates data.",
"fields": { "fields": {
"task_name": { "task_name": {
"name": "Task name", "name": "Task name",

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -10,16 +11,16 @@ from homeassistant.exceptions import HomeAssistantError
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
async def async_generate_text( async def async_generate_data(
hass: HomeAssistant, hass: HomeAssistant,
*, *,
task_name: str, task_name: str,
entity_id: str | None = None, entity_id: str | None = None,
instructions: str, instructions: str,
) -> GenTextTaskResult: ) -> GenDataTaskResult:
"""Run a task in the AI Task integration.""" """Run a task in the AI Task integration."""
if entity_id is None: if entity_id is None:
entity_id = hass.data[DATA_PREFERENCES].gen_text_entity_id entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id
if entity_id is None: if entity_id is None:
raise HomeAssistantError("No entity_id provided and no preferred entity set") raise HomeAssistantError("No entity_id provided and no preferred entity set")
@ -28,13 +29,13 @@ async def async_generate_text(
if entity is None: if entity is None:
raise HomeAssistantError(f"AI Task entity {entity_id} not found") raise HomeAssistantError(f"AI Task entity {entity_id} not found")
if AITaskEntityFeature.GENERATE_TEXT not in entity.supported_features: if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features:
raise HomeAssistantError( raise HomeAssistantError(
f"AI Task entity {entity_id} does not support generating text" f"AI Task entity {entity_id} does not support generating data"
) )
return await entity.internal_async_generate_text( return await entity.internal_async_generate_data(
GenTextTask( GenDataTask(
name=task_name, name=task_name,
instructions=instructions, instructions=instructions,
) )
@ -42,8 +43,8 @@ async def async_generate_text(
@dataclass(slots=True) @dataclass(slots=True)
class GenTextTask: class GenDataTask:
"""Gen text task to be processed.""" """Gen data task to be processed."""
name: str name: str
"""Name of the task.""" """Name of the task."""
@ -53,22 +54,22 @@ class GenTextTask:
def __str__(self) -> str: def __str__(self) -> str:
"""Return task as a string.""" """Return task as a string."""
return f"<GenTextTask {self.name}: {id(self)}>" return f"<GenDataTask {self.name}: {id(self)}>"
@dataclass(slots=True) @dataclass(slots=True)
class GenTextTaskResult: class GenDataTaskResult:
"""Result of gen text task.""" """Result of gen data task."""
conversation_id: str conversation_id: str
"""Unique identifier for the conversation.""" """Unique identifier for the conversation."""
text: str data: Any
"""Generated text.""" """Data generated by the task."""
def as_dict(self) -> dict[str, str]: def as_dict(self) -> dict[str, Any]:
"""Return result as a dict.""" """Return result as a dict."""
return { return {
"conversation_id": self.conversation_id, "conversation_id": self.conversation_id,
"text": self.text, "data": self.data,
} }

View File

@ -6,8 +6,8 @@ from homeassistant.components.ai_task import (
DOMAIN, DOMAIN,
AITaskEntity, AITaskEntity,
AITaskEntityFeature, AITaskEntityFeature,
GenTextTask, GenDataTask,
GenTextTaskResult, GenDataTaskResult,
) )
from homeassistant.components.conversation import AssistantContent, ChatLog from homeassistant.components.conversation import AssistantContent, ChatLog
from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.config_entries import ConfigEntry, ConfigFlow
@ -33,24 +33,24 @@ class MockAITaskEntity(AITaskEntity):
"""Mock AI Task entity for testing.""" """Mock AI Task entity for testing."""
_attr_name = "Test Task Entity" _attr_name = "Test Task Entity"
_attr_supported_features = AITaskEntityFeature.GENERATE_TEXT _attr_supported_features = AITaskEntityFeature.GENERATE_DATA
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the mock entity.""" """Initialize the mock entity."""
super().__init__() super().__init__()
self.mock_generate_text_tasks = [] self.mock_generate_data_tasks = []
async def _async_generate_text( async def _async_generate_data(
self, task: GenTextTask, chat_log: ChatLog self, task: GenDataTask, chat_log: ChatLog
) -> GenTextTaskResult: ) -> GenDataTaskResult:
"""Mock handling of generate text task.""" """Mock handling of generate data task."""
self.mock_generate_text_tasks.append(task) self.mock_generate_data_tasks.append(task)
chat_log.async_add_assistant_content_without_tools( chat_log.async_add_assistant_content_without_tools(
AssistantContent(self.entity_id, "Mock result") AssistantContent(self.entity_id, "Mock result")
) )
return GenTextTaskResult( return GenDataTaskResult(
conversation_id=chat_log.conversation_id, conversation_id=chat_log.conversation_id,
text="Mock result", data="Mock result",
) )

View File

@ -1,5 +1,5 @@
# serializer version: 1 # serializer version: 1
# name: test_run_text_task_updates_chat_log # name: test_run_data_task_updates_chat_log
list([ list([
dict({ dict({
'content': ''' 'content': '''

View File

@ -2,7 +2,7 @@
from freezegun import freeze_time from freezegun import freeze_time
from homeassistant.components.ai_task import async_generate_text from homeassistant.components.ai_task import async_generate_data
from homeassistant.const import STATE_UNKNOWN from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -12,28 +12,28 @@ from tests.common import MockConfigEntry
@freeze_time("2025-06-08 16:28:13") @freeze_time("2025-06-08 16:28:13")
async def test_state_generate_text( async def test_state_generate_data(
hass: HomeAssistant, hass: HomeAssistant,
init_components: None, init_components: None,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
mock_ai_task_entity: MockAITaskEntity, mock_ai_task_entity: MockAITaskEntity,
) -> None: ) -> None:
"""Test the state of the AI Task entity is updated when generating text.""" """Test the state of the AI Task entity is updated when generating data."""
entity = hass.states.get(TEST_ENTITY_ID) entity = hass.states.get(TEST_ENTITY_ID)
assert entity is not None assert entity is not None
assert entity.state == STATE_UNKNOWN assert entity.state == STATE_UNKNOWN
result = await async_generate_text( result = await async_generate_data(
hass, hass,
task_name="Test task", task_name="Test task",
entity_id=TEST_ENTITY_ID, entity_id=TEST_ENTITY_ID,
instructions="Test prompt", instructions="Test prompt",
) )
assert result.text == "Mock result" assert result.data == "Mock result"
entity = hass.states.get(TEST_ENTITY_ID) entity = hass.states.get(TEST_ENTITY_ID)
assert entity.state == "2025-06-08T16:28:13+00:00" assert entity.state == "2025-06-08T16:28:13+00:00"
assert mock_ai_task_entity.mock_generate_text_tasks assert mock_ai_task_entity.mock_generate_data_tasks
task = mock_ai_task_entity.mock_generate_text_tasks[0] task = mock_ai_task_entity.mock_generate_data_tasks[0]
assert task.instructions == "Test prompt" assert task.instructions == "Test prompt"

View File

@ -18,20 +18,20 @@ async def test_ws_preferences(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == { assert msg["result"] == {
"gen_text_entity_id": None, "gen_data_entity_id": None,
} }
# Set preferences # Set preferences
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
"type": "ai_task/preferences/set", "type": "ai_task/preferences/set",
"gen_text_entity_id": "ai_task.summary_1", "gen_data_entity_id": "ai_task.summary_1",
} }
) )
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == { assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_1", "gen_data_entity_id": "ai_task.summary_1",
} }
# Get updated preferences # Get updated preferences
@ -39,20 +39,20 @@ async def test_ws_preferences(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == { assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_1", "gen_data_entity_id": "ai_task.summary_1",
} }
# Update an existing preference # Update an existing preference
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
"type": "ai_task/preferences/set", "type": "ai_task/preferences/set",
"gen_text_entity_id": "ai_task.summary_2", "gen_data_entity_id": "ai_task.summary_2",
} }
) )
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == { assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_2", "gen_data_entity_id": "ai_task.summary_2",
} }
# Get updated preferences # Get updated preferences
@ -60,7 +60,7 @@ async def test_ws_preferences(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == { assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_2", "gen_data_entity_id": "ai_task.summary_2",
} }
# No preferences set will preserve existing preferences # No preferences set will preserve existing preferences
@ -72,7 +72,7 @@ async def test_ws_preferences(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == { assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_2", "gen_data_entity_id": "ai_task.summary_2",
} }
# Get updated preferences # Get updated preferences
@ -80,5 +80,5 @@ async def test_ws_preferences(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == { assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_2", "gen_data_entity_id": "ai_task.summary_2",
} }

View File

@ -49,7 +49,7 @@ async def test_preferences_storage_load(
("set_preferences", "msg_extra"), ("set_preferences", "msg_extra"),
[ [
( (
{"gen_text_entity_id": TEST_ENTITY_ID}, {"gen_data_entity_id": TEST_ENTITY_ID},
{}, {},
), ),
( (
@ -58,20 +58,20 @@ async def test_preferences_storage_load(
), ),
], ],
) )
async def test_generate_text_service( async def test_generate_data_service(
hass: HomeAssistant, hass: HomeAssistant,
init_components: None, init_components: None,
freezer: FrozenDateTimeFactory, freezer: FrozenDateTimeFactory,
set_preferences: dict[str, str | None], set_preferences: dict[str, str | None],
msg_extra: dict[str, str], msg_extra: dict[str, str],
) -> None: ) -> None:
"""Test the generate text service.""" """Test the generate data service."""
preferences = hass.data[DATA_PREFERENCES] preferences = hass.data[DATA_PREFERENCES]
preferences.async_set_preferences(**set_preferences) preferences.async_set_preferences(**set_preferences)
result = await hass.services.async_call( result = await hass.services.async_call(
"ai_task", "ai_task",
"generate_text", "generate_data",
{ {
"task_name": "Test Name", "task_name": "Test Name",
"instructions": "Test prompt", "instructions": "Test prompt",
@ -81,4 +81,4 @@ async def test_generate_text_service(
return_response=True, return_response=True,
) )
assert result["text"] == "Mock result" assert result["data"] == "Mock result"

View File

@ -4,7 +4,7 @@ from freezegun import freeze_time
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_text from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data
from homeassistant.components.conversation import async_get_chat_log from homeassistant.components.conversation import async_get_chat_log
from homeassistant.const import STATE_UNKNOWN from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -28,7 +28,7 @@ async def test_run_task_preferred_entity(
with pytest.raises( with pytest.raises(
HomeAssistantError, match="No entity_id provided and no preferred entity set" HomeAssistantError, match="No entity_id provided and no preferred entity set"
): ):
await async_generate_text( await async_generate_data(
hass, hass,
task_name="Test Task", task_name="Test Task",
instructions="Test prompt", instructions="Test prompt",
@ -37,7 +37,7 @@ async def test_run_task_preferred_entity(
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
"type": "ai_task/preferences/set", "type": "ai_task/preferences/set",
"gen_text_entity_id": "ai_task.unknown", "gen_data_entity_id": "ai_task.unknown",
} }
) )
msg = await client.receive_json() msg = await client.receive_json()
@ -46,7 +46,7 @@ async def test_run_task_preferred_entity(
with pytest.raises( with pytest.raises(
HomeAssistantError, match="AI Task entity ai_task.unknown not found" HomeAssistantError, match="AI Task entity ai_task.unknown not found"
): ):
await async_generate_text( await async_generate_data(
hass, hass,
task_name="Test Task", task_name="Test Task",
instructions="Test prompt", instructions="Test prompt",
@ -55,7 +55,7 @@ async def test_run_task_preferred_entity(
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
"type": "ai_task/preferences/set", "type": "ai_task/preferences/set",
"gen_text_entity_id": TEST_ENTITY_ID, "gen_data_entity_id": TEST_ENTITY_ID,
} }
) )
msg = await client.receive_json() msg = await client.receive_json()
@ -65,12 +65,15 @@ async def test_run_task_preferred_entity(
assert state is not None assert state is not None
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNKNOWN
result = await async_generate_text( result = await async_generate_data(
hass, hass,
task_name="Test Task", task_name="Test Task",
instructions="Test prompt", instructions="Test prompt",
) )
assert result.text == "Mock result" assert result.data == "Mock result"
as_dict = result.as_dict()
assert as_dict["conversation_id"] == result.conversation_id
assert as_dict["data"] == "Mock result"
state = hass.states.get(TEST_ENTITY_ID) state = hass.states.get(TEST_ENTITY_ID)
assert state is not None assert state is not None
assert state.state != STATE_UNKNOWN assert state.state != STATE_UNKNOWN
@ -78,25 +81,25 @@ async def test_run_task_preferred_entity(
mock_ai_task_entity.supported_features = AITaskEntityFeature(0) mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
with pytest.raises( with pytest.raises(
HomeAssistantError, HomeAssistantError,
match="AI Task entity ai_task.test_task_entity does not support generating text", match="AI Task entity ai_task.test_task_entity does not support generating data",
): ):
await async_generate_text( await async_generate_data(
hass, hass,
task_name="Test Task", task_name="Test Task",
instructions="Test prompt", instructions="Test prompt",
) )
async def test_run_text_task_unknown_entity( async def test_run_data_task_unknown_entity(
hass: HomeAssistant, hass: HomeAssistant,
init_components: None, init_components: None,
) -> None: ) -> None:
"""Test running a text task with an unknown entity.""" """Test running a data task with an unknown entity."""
with pytest.raises( with pytest.raises(
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found" HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
): ):
await async_generate_text( await async_generate_data(
hass, hass,
task_name="Test Task", task_name="Test Task",
entity_id="ai_task.unknown_entity", entity_id="ai_task.unknown_entity",
@ -105,19 +108,19 @@ async def test_run_text_task_unknown_entity(
@freeze_time("2025-06-14 22:59:00") @freeze_time("2025-06-14 22:59:00")
async def test_run_text_task_updates_chat_log( async def test_run_data_task_updates_chat_log(
hass: HomeAssistant, hass: HomeAssistant,
init_components: None, init_components: None,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test that running a text task updates the chat log.""" """Test that running a data task updates the chat log."""
result = await async_generate_text( result = await async_generate_data(
hass, hass,
task_name="Test Task", task_name="Test Task",
entity_id=TEST_ENTITY_ID, entity_id=TEST_ENTITY_ID,
instructions="Test prompt", instructions="Test prompt",
) )
assert result.text == "Mock result" assert result.data == "Mock result"
with ( with (
chat_session.async_get_chat_session(hass, result.conversation_id) as session, chat_session.async_get_chat_session(hass, result.conversation_id) as session,