mirror of
https://github.com/home-assistant/core.git
synced 2025-07-13 00:07:10 +00:00
Address AI Task late comments (#147313)
This commit is contained in:
parent
f8267b13d7
commit
a11e274434
@ -5,6 +5,7 @@ import logging
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import ATTR_ENTITY_ID
|
||||||
from homeassistant.core import (
|
from homeassistant.core import (
|
||||||
HassJobType,
|
HassJobType,
|
||||||
HomeAssistant,
|
HomeAssistant,
|
||||||
@ -17,9 +18,17 @@ from homeassistant.helpers import config_validation as cv, storage
|
|||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
|
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
|
||||||
|
|
||||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, DOMAIN, AITaskEntityFeature
|
from .const import (
|
||||||
|
ATTR_INSTRUCTIONS,
|
||||||
|
ATTR_TASK_NAME,
|
||||||
|
DATA_COMPONENT,
|
||||||
|
DATA_PREFERENCES,
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_GENERATE_TEXT,
|
||||||
|
AITaskEntityFeature,
|
||||||
|
)
|
||||||
from .entity import AITaskEntity
|
from .entity import AITaskEntity
|
||||||
from .http import async_setup as async_setup_conversation_http
|
from .http import async_setup as async_setup_http
|
||||||
from .task import GenTextTask, GenTextTaskResult, async_generate_text
|
from .task import GenTextTask, GenTextTaskResult, async_generate_text
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -45,16 +54,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
hass.data[DATA_COMPONENT] = entity_component
|
hass.data[DATA_COMPONENT] = entity_component
|
||||||
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
|
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
|
||||||
await hass.data[DATA_PREFERENCES].async_load()
|
await hass.data[DATA_PREFERENCES].async_load()
|
||||||
async_setup_conversation_http(hass)
|
async_setup_http(hass)
|
||||||
hass.services.async_register(
|
hass.services.async_register(
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
"generate_text",
|
SERVICE_GENERATE_TEXT,
|
||||||
async_service_generate_text,
|
async_service_generate_text,
|
||||||
schema=vol.Schema(
|
schema=vol.Schema(
|
||||||
{
|
{
|
||||||
vol.Required("task_name"): cv.string,
|
vol.Required(ATTR_TASK_NAME): cv.string,
|
||||||
vol.Optional("entity_id"): cv.entity_id,
|
vol.Optional(ATTR_ENTITY_ID): cv.entity_id,
|
||||||
vol.Required("instructions"): cv.string,
|
vol.Required(ATTR_INSTRUCTIONS): cv.string,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
supports_response=SupportsResponse.ONLY,
|
supports_response=SupportsResponse.ONLY,
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import IntFlag
|
from enum import IntFlag
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Final
|
||||||
|
|
||||||
from homeassistant.util.hass_dict import HassKey
|
from homeassistant.util.hass_dict import HassKey
|
||||||
|
|
||||||
@ -17,6 +17,11 @@ DOMAIN = "ai_task"
|
|||||||
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
|
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
|
||||||
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
|
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
|
||||||
|
|
||||||
|
SERVICE_GENERATE_TEXT = "generate_text"
|
||||||
|
|
||||||
|
ATTR_INSTRUCTIONS: Final = "instructions"
|
||||||
|
ATTR_TASK_NAME: Final = "task_name"
|
||||||
|
|
||||||
DEFAULT_SYSTEM_PROMPT = (
|
DEFAULT_SYSTEM_PROMPT = (
|
||||||
"You are a Home Assistant expert and help users with their tasks."
|
"You are a Home Assistant expert and help users with their tasks."
|
||||||
)
|
)
|
||||||
|
@ -6,7 +6,7 @@ generate_text:
|
|||||||
selector:
|
selector:
|
||||||
text:
|
text:
|
||||||
instructions:
|
instructions:
|
||||||
example: "Generate a funny notification that garage door was left open"
|
example: "Generate a funny notification that the garage door was left open"
|
||||||
required: true
|
required: true
|
||||||
selector:
|
selector:
|
||||||
text:
|
text:
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
"description": "Use AI to run a task that generates text.",
|
"description": "Use AI to run a task that generates text.",
|
||||||
"fields": {
|
"fields": {
|
||||||
"task_name": {
|
"task_name": {
|
||||||
"name": "Task Name",
|
"name": "Task name",
|
||||||
"description": "Name of the task."
|
"description": "Name of the task."
|
||||||
},
|
},
|
||||||
"instructions": {
|
"instructions": {
|
||||||
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||||
|
|
||||||
@ -21,14 +22,16 @@ async def async_generate_text(
|
|||||||
entity_id = hass.data[DATA_PREFERENCES].gen_text_entity_id
|
entity_id = hass.data[DATA_PREFERENCES].gen_text_entity_id
|
||||||
|
|
||||||
if entity_id is None:
|
if entity_id is None:
|
||||||
raise ValueError("No entity_id provided and no preferred entity set")
|
raise HomeAssistantError("No entity_id provided and no preferred entity set")
|
||||||
|
|
||||||
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
|
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
|
||||||
if entity is None:
|
if entity is None:
|
||||||
raise ValueError(f"AI Task entity {entity_id} not found")
|
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
|
||||||
|
|
||||||
if AITaskEntityFeature.GENERATE_TEXT not in entity.supported_features:
|
if AITaskEntityFeature.GENERATE_TEXT not in entity.supported_features:
|
||||||
raise ValueError(f"AI Task entity {entity_id} does not support generating text")
|
raise HomeAssistantError(
|
||||||
|
f"AI Task entity {entity_id} does not support generating text"
|
||||||
|
)
|
||||||
|
|
||||||
return await entity.internal_async_generate_text(
|
return await entity.internal_async_generate_text(
|
||||||
GenTextTask(
|
GenTextTask(
|
||||||
|
@ -8,6 +8,7 @@ from homeassistant.components.ai_task import AITaskEntityFeature, async_generate
|
|||||||
from homeassistant.components.conversation import async_get_chat_log
|
from homeassistant.components.conversation import async_get_chat_log
|
||||||
from homeassistant.const import STATE_UNKNOWN
|
from homeassistant.const import STATE_UNKNOWN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import chat_session
|
from homeassistant.helpers import chat_session
|
||||||
|
|
||||||
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||||
@ -25,7 +26,7 @@ async def test_run_task_preferred_entity(
|
|||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="No entity_id provided and no preferred entity set"
|
HomeAssistantError, match="No entity_id provided and no preferred entity set"
|
||||||
):
|
):
|
||||||
await async_generate_text(
|
await async_generate_text(
|
||||||
hass,
|
hass,
|
||||||
@ -42,7 +43,9 @@ async def test_run_task_preferred_entity(
|
|||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="AI Task entity ai_task.unknown not found"):
|
with pytest.raises(
|
||||||
|
HomeAssistantError, match="AI Task entity ai_task.unknown not found"
|
||||||
|
):
|
||||||
await async_generate_text(
|
await async_generate_text(
|
||||||
hass,
|
hass,
|
||||||
task_name="Test Task",
|
task_name="Test Task",
|
||||||
@ -74,7 +77,7 @@ async def test_run_task_preferred_entity(
|
|||||||
|
|
||||||
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
|
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
HomeAssistantError,
|
||||||
match="AI Task entity ai_task.test_task_entity does not support generating text",
|
match="AI Task entity ai_task.test_task_entity does not support generating text",
|
||||||
):
|
):
|
||||||
await async_generate_text(
|
await async_generate_text(
|
||||||
@ -91,7 +94,7 @@ async def test_run_text_task_unknown_entity(
|
|||||||
"""Test running a text task with an unknown entity."""
|
"""Test running a text task with an unknown entity."""
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="AI Task entity ai_task.unknown_entity not found"
|
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
|
||||||
):
|
):
|
||||||
await async_generate_text(
|
await async_generate_text(
|
||||||
hass,
|
hass,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user