mirror of
https://github.com/home-assistant/core.git
synced 2025-07-08 13:57:10 +00:00
Add an LLM tool for fetching todo list items (#143777)
* Add a tool for fetching todo list items * Simplify the todo list interface by adding an "all" status * Update prompt to improve performance on smaller models
This commit is contained in:
parent
bdd9099294
commit
b16151ac6d
@ -95,6 +95,12 @@ TODO_ITEM_FIELD_SCHEMA = {
|
|||||||
vol.Optional(desc.service_field): desc.validation for desc in TODO_ITEM_FIELDS
|
vol.Optional(desc.service_field): desc.validation for desc in TODO_ITEM_FIELDS
|
||||||
}
|
}
|
||||||
TODO_ITEM_FIELD_VALIDATIONS = [cv.has_at_most_one_key(ATTR_DUE_DATE, ATTR_DUE_DATETIME)]
|
TODO_ITEM_FIELD_VALIDATIONS = [cv.has_at_most_one_key(ATTR_DUE_DATE, ATTR_DUE_DATETIME)]
|
||||||
|
TODO_SERVICE_GET_ITEMS_SCHEMA = {
|
||||||
|
vol.Optional(ATTR_STATUS): vol.All(
|
||||||
|
cv.ensure_list,
|
||||||
|
[vol.In({TodoItemStatus.NEEDS_ACTION, TodoItemStatus.COMPLETED})],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _validate_supported_features(
|
def _validate_supported_features(
|
||||||
@ -177,14 +183,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
)
|
)
|
||||||
component.async_register_entity_service(
|
component.async_register_entity_service(
|
||||||
TodoServices.GET_ITEMS,
|
TodoServices.GET_ITEMS,
|
||||||
cv.make_entity_service_schema(
|
cv.make_entity_service_schema(TODO_SERVICE_GET_ITEMS_SCHEMA),
|
||||||
{
|
|
||||||
vol.Optional(ATTR_STATUS): vol.All(
|
|
||||||
cv.ensure_list,
|
|
||||||
[vol.In({TodoItemStatus.NEEDS_ACTION, TodoItemStatus.COMPLETED})],
|
|
||||||
),
|
|
||||||
}
|
|
||||||
),
|
|
||||||
_async_get_todo_items,
|
_async_get_todo_items,
|
||||||
supports_response=SupportsResponse.ONLY,
|
supports_response=SupportsResponse.ONLY,
|
||||||
)
|
)
|
||||||
|
@ -24,6 +24,7 @@ from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
|||||||
from homeassistant.components.homeassistant import async_should_expose
|
from homeassistant.components.homeassistant import async_should_expose
|
||||||
from homeassistant.components.intent import async_device_supports_timers
|
from homeassistant.components.intent import async_device_supports_timers
|
||||||
from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN
|
from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN
|
||||||
|
from homeassistant.components.todo import DOMAIN as TODO_DOMAIN, TodoServices
|
||||||
from homeassistant.components.weather import INTENT_GET_WEATHER
|
from homeassistant.components.weather import INTENT_GET_WEATHER
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_DOMAIN,
|
ATTR_DOMAIN,
|
||||||
@ -577,6 +578,14 @@ class AssistAPI(API):
|
|||||||
names.extend(info["names"].split(", "))
|
names.extend(info["names"].split(", "))
|
||||||
tools.append(CalendarGetEventsTool(names))
|
tools.append(CalendarGetEventsTool(names))
|
||||||
|
|
||||||
|
if exposed_domains is not None and TODO_DOMAIN in exposed_domains:
|
||||||
|
names = []
|
||||||
|
for info in exposed_entities["entities"].values():
|
||||||
|
if info["domain"] != TODO_DOMAIN:
|
||||||
|
continue
|
||||||
|
names.extend(info["names"].split(", "))
|
||||||
|
tools.append(TodoGetItemsTool(names))
|
||||||
|
|
||||||
tools.extend(
|
tools.extend(
|
||||||
ScriptTool(self.hass, script_entity_id)
|
ScriptTool(self.hass, script_entity_id)
|
||||||
for script_entity_id in exposed_entities[SCRIPT_DOMAIN]
|
for script_entity_id in exposed_entities[SCRIPT_DOMAIN]
|
||||||
@ -1024,6 +1033,65 @@ class CalendarGetEventsTool(Tool):
|
|||||||
return {"success": True, "result": events}
|
return {"success": True, "result": events}
|
||||||
|
|
||||||
|
|
||||||
|
class TodoGetItemsTool(Tool):
|
||||||
|
"""LLM Tool allowing querying a to-do list."""
|
||||||
|
|
||||||
|
name = "todo_get_items"
|
||||||
|
description = (
|
||||||
|
"Query a to-do list to find out what items are on it. "
|
||||||
|
"Use this to answer questions like 'What's on my task list?' or 'Read my grocery list'. "
|
||||||
|
"Filters items by status (needs_action, completed, all)."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, todo_lists: list[str]) -> None:
|
||||||
|
"""Init the get items tool."""
|
||||||
|
self.parameters = vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required("todo_list"): vol.In(todo_lists),
|
||||||
|
vol.Optional(
|
||||||
|
"status",
|
||||||
|
description="Filter returned items by status, by default returns incomplete items",
|
||||||
|
default="needs_action",
|
||||||
|
): vol.In(["needs_action", "completed", "all"]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_call(
|
||||||
|
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||||
|
) -> JsonObjectType:
|
||||||
|
"""Query a to-do list."""
|
||||||
|
data = self.parameters(tool_input.tool_args)
|
||||||
|
result = intent.async_match_targets(
|
||||||
|
hass,
|
||||||
|
intent.MatchTargetsConstraints(
|
||||||
|
name=data["todo_list"],
|
||||||
|
domains=[TODO_DOMAIN],
|
||||||
|
assistant=llm_context.assistant,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not result.is_match:
|
||||||
|
return {"success": False, "error": "To-do list not found"}
|
||||||
|
entity_id = result.states[0].entity_id
|
||||||
|
service_data: dict[str, Any] = {"entity_id": entity_id}
|
||||||
|
if status := data.get("status"):
|
||||||
|
if status == "all":
|
||||||
|
service_data["status"] = ["needs_action", "completed"]
|
||||||
|
else:
|
||||||
|
service_data["status"] = [status]
|
||||||
|
service_result = await hass.services.async_call(
|
||||||
|
TODO_DOMAIN,
|
||||||
|
TodoServices.GET_ITEMS,
|
||||||
|
service_data,
|
||||||
|
context=llm_context.context,
|
||||||
|
blocking=True,
|
||||||
|
return_response=True,
|
||||||
|
)
|
||||||
|
if not service_result:
|
||||||
|
return {"success": False, "error": "To-do list not found"}
|
||||||
|
items = cast(dict, service_result)[entity_id]["items"]
|
||||||
|
return {"success": True, "result": items}
|
||||||
|
|
||||||
|
|
||||||
class GetLiveContextTool(Tool):
|
class GetLiveContextTool(Tool):
|
||||||
"""Tool for getting the current state of exposed entities.
|
"""Tool for getting the current state of exposed entities.
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import calendar
|
from homeassistant.components import calendar, todo
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||||
from homeassistant.components.intent import async_register_timer_handler
|
from homeassistant.components.intent import async_register_timer_handler
|
||||||
from homeassistant.components.script.config import ScriptConfig
|
from homeassistant.components.script.config import ScriptConfig
|
||||||
@ -1332,6 +1332,118 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_todo_get_items_tool(hass: HomeAssistant) -> None:
|
||||||
|
"""Test the todo get items tool."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
assert await async_setup_component(hass, "todo", {})
|
||||||
|
hass.states.async_set(
|
||||||
|
"todo.test_list", "0", {"friendly_name": "Mock Todo List Name"}
|
||||||
|
)
|
||||||
|
async_expose_entity(hass, "conversation", "todo.test_list", True)
|
||||||
|
context = Context()
|
||||||
|
llm_context = llm.LLMContext(
|
||||||
|
platform="test_platform",
|
||||||
|
context=context,
|
||||||
|
user_prompt="test_text",
|
||||||
|
language="*",
|
||||||
|
assistant="conversation",
|
||||||
|
device_id=None,
|
||||||
|
)
|
||||||
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
|
tool = next((tool for tool in api.tools if tool.name == "todo_get_items"), None)
|
||||||
|
assert tool is not None
|
||||||
|
assert tool.parameters.schema["todo_list"].container == ["Mock Todo List Name"]
|
||||||
|
|
||||||
|
calls = async_mock_service(
|
||||||
|
hass,
|
||||||
|
domain=todo.DOMAIN,
|
||||||
|
service=todo.TodoServices.GET_ITEMS,
|
||||||
|
schema=cv.make_entity_service_schema(todo.TODO_SERVICE_GET_ITEMS_SCHEMA),
|
||||||
|
response={
|
||||||
|
"todo.test_list": {
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"uid": "1234",
|
||||||
|
"summary": "Buy milk",
|
||||||
|
"status": "needs_action",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"uid": "5678",
|
||||||
|
"summary": "Call mom",
|
||||||
|
"status": "needs_action",
|
||||||
|
"due": "2025-09-17",
|
||||||
|
"description": "Remember birthday",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test without status filter (defaults to needs_action)
|
||||||
|
result = await tool.async_call(
|
||||||
|
hass,
|
||||||
|
llm.ToolInput("todo_get_items", {"todo_list": "Mock Todo List Name"}),
|
||||||
|
llm_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(calls) == 1
|
||||||
|
assert calls[0].data == {
|
||||||
|
"entity_id": ["todo.test_list"],
|
||||||
|
"status": ["needs_action"],
|
||||||
|
}
|
||||||
|
assert result == {
|
||||||
|
"success": True,
|
||||||
|
"result": [
|
||||||
|
{
|
||||||
|
"uid": "1234",
|
||||||
|
"status": "needs_action",
|
||||||
|
"summary": "Buy milk",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"uid": "5678",
|
||||||
|
"status": "needs_action",
|
||||||
|
"summary": "Call mom",
|
||||||
|
"due": "2025-09-17",
|
||||||
|
"description": "Remember birthday",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test that the status filter is passed correctly to the service call.
|
||||||
|
# We don't assert on the response since it is fixed above.
|
||||||
|
calls.clear()
|
||||||
|
result = await tool.async_call(
|
||||||
|
hass,
|
||||||
|
llm.ToolInput(
|
||||||
|
"todo_get_items",
|
||||||
|
{"todo_list": "Mock Todo List Name", "status": "completed"},
|
||||||
|
),
|
||||||
|
llm_context,
|
||||||
|
)
|
||||||
|
assert len(calls) == 1
|
||||||
|
assert calls[0].data == {
|
||||||
|
"entity_id": ["todo.test_list"],
|
||||||
|
"status": ["completed"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test that the status filter is passed correctly to the service call.
|
||||||
|
# We don't assert on the response since it is fixed above.
|
||||||
|
calls.clear()
|
||||||
|
result = await tool.async_call(
|
||||||
|
hass,
|
||||||
|
llm.ToolInput(
|
||||||
|
"todo_get_items",
|
||||||
|
{"todo_list": "Mock Todo List Name", "status": "all"},
|
||||||
|
),
|
||||||
|
llm_context,
|
||||||
|
)
|
||||||
|
assert len(calls) == 1
|
||||||
|
assert calls[0].data == {
|
||||||
|
"entity_id": ["todo.test_list"],
|
||||||
|
"status": ["needs_action", "completed"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_no_tools_exposed(hass: HomeAssistant) -> None:
|
async def test_no_tools_exposed(hass: HomeAssistant) -> None:
|
||||||
"""Test that tools are not exposed when no entities are exposed."""
|
"""Test that tools are not exposed when no entities are exposed."""
|
||||||
assert await async_setup_component(hass, "homeassistant", {})
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user