mirror of
https://github.com/home-assistant/core.git
synced 2025-07-09 22:37:11 +00:00
Address AI Task late comments (#147313)
This commit is contained in:
parent
f8267b13d7
commit
a11e274434
@ -5,6 +5,7 @@ import logging
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import ATTR_ENTITY_ID
|
||||
from homeassistant.core import (
|
||||
HassJobType,
|
||||
HomeAssistant,
|
||||
@ -17,9 +18,17 @@ from homeassistant.helpers import config_validation as cv, storage
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
|
||||
|
||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, DOMAIN, AITaskEntityFeature
|
||||
from .const import (
|
||||
ATTR_INSTRUCTIONS,
|
||||
ATTR_TASK_NAME,
|
||||
DATA_COMPONENT,
|
||||
DATA_PREFERENCES,
|
||||
DOMAIN,
|
||||
SERVICE_GENERATE_TEXT,
|
||||
AITaskEntityFeature,
|
||||
)
|
||||
from .entity import AITaskEntity
|
||||
from .http import async_setup as async_setup_conversation_http
|
||||
from .http import async_setup as async_setup_http
|
||||
from .task import GenTextTask, GenTextTaskResult, async_generate_text
|
||||
|
||||
__all__ = [
|
||||
@ -45,16 +54,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
hass.data[DATA_COMPONENT] = entity_component
|
||||
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
|
||||
await hass.data[DATA_PREFERENCES].async_load()
|
||||
async_setup_conversation_http(hass)
|
||||
async_setup_http(hass)
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
"generate_text",
|
||||
SERVICE_GENERATE_TEXT,
|
||||
async_service_generate_text,
|
||||
schema=vol.Schema(
|
||||
{
|
||||
vol.Required("task_name"): cv.string,
|
||||
vol.Optional("entity_id"): cv.entity_id,
|
||||
vol.Required("instructions"): cv.string,
|
||||
vol.Required(ATTR_TASK_NAME): cv.string,
|
||||
vol.Optional(ATTR_ENTITY_ID): cv.entity_id,
|
||||
vol.Required(ATTR_INSTRUCTIONS): cv.string,
|
||||
}
|
||||
),
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntFlag
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
@ -17,6 +17,11 @@ DOMAIN = "ai_task"
|
||||
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
|
||||
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
|
||||
|
||||
SERVICE_GENERATE_TEXT = "generate_text"
|
||||
|
||||
ATTR_INSTRUCTIONS: Final = "instructions"
|
||||
ATTR_TASK_NAME: Final = "task_name"
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"You are a Home Assistant expert and help users with their tasks."
|
||||
)
|
||||
|
@ -6,7 +6,7 @@ generate_text:
|
||||
selector:
|
||||
text:
|
||||
instructions:
|
||||
example: "Generate a funny notification that garage door was left open"
|
||||
example: "Generate a funny notification that the garage door was left open"
|
||||
required: true
|
||||
selector:
|
||||
text:
|
||||
|
@ -5,7 +5,7 @@
|
||||
"description": "Use AI to run a task that generates text.",
|
||||
"fields": {
|
||||
"task_name": {
|
||||
"name": "Task Name",
|
||||
"name": "Task name",
|
||||
"description": "Name of the task."
|
||||
},
|
||||
"instructions": {
|
||||
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||
|
||||
@ -21,14 +22,16 @@ async def async_generate_text(
|
||||
entity_id = hass.data[DATA_PREFERENCES].gen_text_entity_id
|
||||
|
||||
if entity_id is None:
|
||||
raise ValueError("No entity_id provided and no preferred entity set")
|
||||
raise HomeAssistantError("No entity_id provided and no preferred entity set")
|
||||
|
||||
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
|
||||
if entity is None:
|
||||
raise ValueError(f"AI Task entity {entity_id} not found")
|
||||
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
|
||||
|
||||
if AITaskEntityFeature.GENERATE_TEXT not in entity.supported_features:
|
||||
raise ValueError(f"AI Task entity {entity_id} does not support generating text")
|
||||
raise HomeAssistantError(
|
||||
f"AI Task entity {entity_id} does not support generating text"
|
||||
)
|
||||
|
||||
return await entity.internal_async_generate_text(
|
||||
GenTextTask(
|
||||
|
@ -8,6 +8,7 @@ from homeassistant.components.ai_task import AITaskEntityFeature, async_generate
|
||||
from homeassistant.components.conversation import async_get_chat_log
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import chat_session
|
||||
|
||||
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||
@ -25,7 +26,7 @@ async def test_run_task_preferred_entity(
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="No entity_id provided and no preferred entity set"
|
||||
HomeAssistantError, match="No entity_id provided and no preferred entity set"
|
||||
):
|
||||
await async_generate_text(
|
||||
hass,
|
||||
@ -42,7 +43,9 @@ async def test_run_task_preferred_entity(
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
with pytest.raises(ValueError, match="AI Task entity ai_task.unknown not found"):
|
||||
with pytest.raises(
|
||||
HomeAssistantError, match="AI Task entity ai_task.unknown not found"
|
||||
):
|
||||
await async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
@ -74,7 +77,7 @@ async def test_run_task_preferred_entity(
|
||||
|
||||
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
HomeAssistantError,
|
||||
match="AI Task entity ai_task.test_task_entity does not support generating text",
|
||||
):
|
||||
await async_generate_text(
|
||||
@ -91,7 +94,7 @@ async def test_run_text_task_unknown_entity(
|
||||
"""Test running a text task with an unknown entity."""
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="AI Task entity ai_task.unknown_entity not found"
|
||||
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
|
||||
):
|
||||
await async_generate_text(
|
||||
hass,
|
||||
|
Loading…
x
Reference in New Issue
Block a user