Address AI Task late comments (#147313)

This commit is contained in:
Paulus Schoutsen 2025-06-23 10:58:42 -04:00 committed by GitHub
parent f8267b13d7
commit a11e274434
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 37 additions and 17 deletions

View File

@ -5,6 +5,7 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.core import ( from homeassistant.core import (
HassJobType, HassJobType,
HomeAssistant, HomeAssistant,
@ -17,9 +18,17 @@ from homeassistant.helpers import config_validation as cv, storage
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
from .const import DATA_COMPONENT, DATA_PREFERENCES, DOMAIN, AITaskEntityFeature from .const import (
ATTR_INSTRUCTIONS,
ATTR_TASK_NAME,
DATA_COMPONENT,
DATA_PREFERENCES,
DOMAIN,
SERVICE_GENERATE_TEXT,
AITaskEntityFeature,
)
from .entity import AITaskEntity from .entity import AITaskEntity
from .http import async_setup as async_setup_conversation_http from .http import async_setup as async_setup_http
from .task import GenTextTask, GenTextTaskResult, async_generate_text from .task import GenTextTask, GenTextTaskResult, async_generate_text
__all__ = [ __all__ = [
@ -45,16 +54,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass.data[DATA_COMPONENT] = entity_component hass.data[DATA_COMPONENT] = entity_component
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass) hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
await hass.data[DATA_PREFERENCES].async_load() await hass.data[DATA_PREFERENCES].async_load()
async_setup_conversation_http(hass) async_setup_http(hass)
hass.services.async_register( hass.services.async_register(
DOMAIN, DOMAIN,
"generate_text", SERVICE_GENERATE_TEXT,
async_service_generate_text, async_service_generate_text,
schema=vol.Schema( schema=vol.Schema(
{ {
vol.Required("task_name"): cv.string, vol.Required(ATTR_TASK_NAME): cv.string,
vol.Optional("entity_id"): cv.entity_id, vol.Optional(ATTR_ENTITY_ID): cv.entity_id,
vol.Required("instructions"): cv.string, vol.Required(ATTR_INSTRUCTIONS): cv.string,
} }
), ),
supports_response=SupportsResponse.ONLY, supports_response=SupportsResponse.ONLY,

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from enum import IntFlag from enum import IntFlag
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Final
from homeassistant.util.hass_dict import HassKey from homeassistant.util.hass_dict import HassKey
@ -17,6 +17,11 @@ DOMAIN = "ai_task"
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN) DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences") DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
SERVICE_GENERATE_TEXT = "generate_text"
ATTR_INSTRUCTIONS: Final = "instructions"
ATTR_TASK_NAME: Final = "task_name"
DEFAULT_SYSTEM_PROMPT = ( DEFAULT_SYSTEM_PROMPT = (
"You are a Home Assistant expert and help users with their tasks." "You are a Home Assistant expert and help users with their tasks."
) )

View File

@ -6,7 +6,7 @@ generate_text:
selector: selector:
text: text:
instructions: instructions:
example: "Generate a funny notification that garage door was left open" example: "Generate a funny notification that the garage door was left open"
required: true required: true
selector: selector:
text: text:

View File

@ -5,7 +5,7 @@
"description": "Use AI to run a task that generates text.", "description": "Use AI to run a task that generates text.",
"fields": { "fields": {
"task_name": { "task_name": {
"name": "Task Name", "name": "Task name",
"description": "Name of the task." "description": "Name of the task."
}, },
"instructions": { "instructions": {

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
@ -21,14 +22,16 @@ async def async_generate_text(
entity_id = hass.data[DATA_PREFERENCES].gen_text_entity_id entity_id = hass.data[DATA_PREFERENCES].gen_text_entity_id
if entity_id is None: if entity_id is None:
raise ValueError("No entity_id provided and no preferred entity set") raise HomeAssistantError("No entity_id provided and no preferred entity set")
entity = hass.data[DATA_COMPONENT].get_entity(entity_id) entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
if entity is None: if entity is None:
raise ValueError(f"AI Task entity {entity_id} not found") raise HomeAssistantError(f"AI Task entity {entity_id} not found")
if AITaskEntityFeature.GENERATE_TEXT not in entity.supported_features: if AITaskEntityFeature.GENERATE_TEXT not in entity.supported_features:
raise ValueError(f"AI Task entity {entity_id} does not support generating text") raise HomeAssistantError(
f"AI Task entity {entity_id} does not support generating text"
)
return await entity.internal_async_generate_text( return await entity.internal_async_generate_text(
GenTextTask( GenTextTask(

View File

@ -8,6 +8,7 @@ from homeassistant.components.ai_task import AITaskEntityFeature, async_generate
from homeassistant.components.conversation import async_get_chat_log from homeassistant.components.conversation import async_get_chat_log
from homeassistant.const import STATE_UNKNOWN from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session from homeassistant.helpers import chat_session
from .conftest import TEST_ENTITY_ID, MockAITaskEntity from .conftest import TEST_ENTITY_ID, MockAITaskEntity
@ -25,7 +26,7 @@ async def test_run_task_preferred_entity(
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
with pytest.raises( with pytest.raises(
ValueError, match="No entity_id provided and no preferred entity set" HomeAssistantError, match="No entity_id provided and no preferred entity set"
): ):
await async_generate_text( await async_generate_text(
hass, hass,
@ -42,7 +43,9 @@ async def test_run_task_preferred_entity(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
with pytest.raises(ValueError, match="AI Task entity ai_task.unknown not found"): with pytest.raises(
HomeAssistantError, match="AI Task entity ai_task.unknown not found"
):
await async_generate_text( await async_generate_text(
hass, hass,
task_name="Test Task", task_name="Test Task",
@ -74,7 +77,7 @@ async def test_run_task_preferred_entity(
mock_ai_task_entity.supported_features = AITaskEntityFeature(0) mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
with pytest.raises( with pytest.raises(
ValueError, HomeAssistantError,
match="AI Task entity ai_task.test_task_entity does not support generating text", match="AI Task entity ai_task.test_task_entity does not support generating text",
): ):
await async_generate_text( await async_generate_text(
@ -91,7 +94,7 @@ async def test_run_text_task_unknown_entity(
"""Test running a text task with an unknown entity.""" """Test running a text task with an unknown entity."""
with pytest.raises( with pytest.raises(
ValueError, match="AI Task entity ai_task.unknown_entity not found" HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
): ):
await async_generate_text( await async_generate_text(
hass, hass,