mirror of
https://github.com/home-assistant/core.git
synced 2025-07-07 21:37:07 +00:00
Add an LLM tool for fetching todo list items (#143777)
* Add a tool for fetching todo list items * Simplify the todo list interface by adding an "all" status * Update prompt to improve performance on smaller models
This commit is contained in:
parent
bdd9099294
commit
b16151ac6d
@ -95,6 +95,12 @@ TODO_ITEM_FIELD_SCHEMA = {
|
||||
vol.Optional(desc.service_field): desc.validation for desc in TODO_ITEM_FIELDS
|
||||
}
|
||||
TODO_ITEM_FIELD_VALIDATIONS = [cv.has_at_most_one_key(ATTR_DUE_DATE, ATTR_DUE_DATETIME)]
|
||||
TODO_SERVICE_GET_ITEMS_SCHEMA = {
|
||||
vol.Optional(ATTR_STATUS): vol.All(
|
||||
cv.ensure_list,
|
||||
[vol.In({TodoItemStatus.NEEDS_ACTION, TodoItemStatus.COMPLETED})],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _validate_supported_features(
|
||||
@ -177,14 +183,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
)
|
||||
component.async_register_entity_service(
|
||||
TodoServices.GET_ITEMS,
|
||||
cv.make_entity_service_schema(
|
||||
{
|
||||
vol.Optional(ATTR_STATUS): vol.All(
|
||||
cv.ensure_list,
|
||||
[vol.In({TodoItemStatus.NEEDS_ACTION, TodoItemStatus.COMPLETED})],
|
||||
),
|
||||
}
|
||||
),
|
||||
cv.make_entity_service_schema(TODO_SERVICE_GET_ITEMS_SCHEMA),
|
||||
_async_get_todo_items,
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
|
@ -24,6 +24,7 @@ from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
||||
from homeassistant.components.homeassistant import async_should_expose
|
||||
from homeassistant.components.intent import async_device_supports_timers
|
||||
from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN
|
||||
from homeassistant.components.todo import DOMAIN as TODO_DOMAIN, TodoServices
|
||||
from homeassistant.components.weather import INTENT_GET_WEATHER
|
||||
from homeassistant.const import (
|
||||
ATTR_DOMAIN,
|
||||
@ -577,6 +578,14 @@ class AssistAPI(API):
|
||||
names.extend(info["names"].split(", "))
|
||||
tools.append(CalendarGetEventsTool(names))
|
||||
|
||||
if exposed_domains is not None and TODO_DOMAIN in exposed_domains:
|
||||
names = []
|
||||
for info in exposed_entities["entities"].values():
|
||||
if info["domain"] != TODO_DOMAIN:
|
||||
continue
|
||||
names.extend(info["names"].split(", "))
|
||||
tools.append(TodoGetItemsTool(names))
|
||||
|
||||
tools.extend(
|
||||
ScriptTool(self.hass, script_entity_id)
|
||||
for script_entity_id in exposed_entities[SCRIPT_DOMAIN]
|
||||
@ -1024,6 +1033,65 @@ class CalendarGetEventsTool(Tool):
|
||||
return {"success": True, "result": events}
|
||||
|
||||
|
||||
class TodoGetItemsTool(Tool):
|
||||
"""LLM Tool allowing querying a to-do list."""
|
||||
|
||||
name = "todo_get_items"
|
||||
description = (
|
||||
"Query a to-do list to find out what items are on it. "
|
||||
"Use this to answer questions like 'What's on my task list?' or 'Read my grocery list'. "
|
||||
"Filters items by status (needs_action, completed, all)."
|
||||
)
|
||||
|
||||
def __init__(self, todo_lists: list[str]) -> None:
|
||||
"""Init the get items tool."""
|
||||
self.parameters = vol.Schema(
|
||||
{
|
||||
vol.Required("todo_list"): vol.In(todo_lists),
|
||||
vol.Optional(
|
||||
"status",
|
||||
description="Filter returned items by status, by default returns incomplete items",
|
||||
default="needs_action",
|
||||
): vol.In(["needs_action", "completed", "all"]),
|
||||
}
|
||||
)
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||
) -> JsonObjectType:
|
||||
"""Query a to-do list."""
|
||||
data = self.parameters(tool_input.tool_args)
|
||||
result = intent.async_match_targets(
|
||||
hass,
|
||||
intent.MatchTargetsConstraints(
|
||||
name=data["todo_list"],
|
||||
domains=[TODO_DOMAIN],
|
||||
assistant=llm_context.assistant,
|
||||
),
|
||||
)
|
||||
if not result.is_match:
|
||||
return {"success": False, "error": "To-do list not found"}
|
||||
entity_id = result.states[0].entity_id
|
||||
service_data: dict[str, Any] = {"entity_id": entity_id}
|
||||
if status := data.get("status"):
|
||||
if status == "all":
|
||||
service_data["status"] = ["needs_action", "completed"]
|
||||
else:
|
||||
service_data["status"] = [status]
|
||||
service_result = await hass.services.async_call(
|
||||
TODO_DOMAIN,
|
||||
TodoServices.GET_ITEMS,
|
||||
service_data,
|
||||
context=llm_context.context,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
if not service_result:
|
||||
return {"success": False, "error": "To-do list not found"}
|
||||
items = cast(dict, service_result)[entity_id]["items"]
|
||||
return {"success": True, "result": items}
|
||||
|
||||
|
||||
class GetLiveContextTool(Tool):
|
||||
"""Tool for getting the current state of exposed entities.
|
||||
|
||||
|
@ -7,7 +7,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import calendar
|
||||
from homeassistant.components import calendar, todo
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||
from homeassistant.components.intent import async_register_timer_handler
|
||||
from homeassistant.components.script.config import ScriptConfig
|
||||
@ -1332,6 +1332,118 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
|
||||
}
|
||||
|
||||
|
||||
async def test_todo_get_items_tool(hass: HomeAssistant) -> None:
|
||||
"""Test the todo get items tool."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "todo", {})
|
||||
hass.states.async_set(
|
||||
"todo.test_list", "0", {"friendly_name": "Mock Todo List Name"}
|
||||
)
|
||||
async_expose_entity(hass, "conversation", "todo.test_list", True)
|
||||
context = Context()
|
||||
llm_context = llm.LLMContext(
|
||||
platform="test_platform",
|
||||
context=context,
|
||||
user_prompt="test_text",
|
||||
language="*",
|
||||
assistant="conversation",
|
||||
device_id=None,
|
||||
)
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
tool = next((tool for tool in api.tools if tool.name == "todo_get_items"), None)
|
||||
assert tool is not None
|
||||
assert tool.parameters.schema["todo_list"].container == ["Mock Todo List Name"]
|
||||
|
||||
calls = async_mock_service(
|
||||
hass,
|
||||
domain=todo.DOMAIN,
|
||||
service=todo.TodoServices.GET_ITEMS,
|
||||
schema=cv.make_entity_service_schema(todo.TODO_SERVICE_GET_ITEMS_SCHEMA),
|
||||
response={
|
||||
"todo.test_list": {
|
||||
"items": [
|
||||
{
|
||||
"uid": "1234",
|
||||
"summary": "Buy milk",
|
||||
"status": "needs_action",
|
||||
},
|
||||
{
|
||||
"uid": "5678",
|
||||
"summary": "Call mom",
|
||||
"status": "needs_action",
|
||||
"due": "2025-09-17",
|
||||
"description": "Remember birthday",
|
||||
},
|
||||
]
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Test without status filter (defaults to needs_action)
|
||||
result = await tool.async_call(
|
||||
hass,
|
||||
llm.ToolInput("todo_get_items", {"todo_list": "Mock Todo List Name"}),
|
||||
llm_context,
|
||||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data == {
|
||||
"entity_id": ["todo.test_list"],
|
||||
"status": ["needs_action"],
|
||||
}
|
||||
assert result == {
|
||||
"success": True,
|
||||
"result": [
|
||||
{
|
||||
"uid": "1234",
|
||||
"status": "needs_action",
|
||||
"summary": "Buy milk",
|
||||
},
|
||||
{
|
||||
"uid": "5678",
|
||||
"status": "needs_action",
|
||||
"summary": "Call mom",
|
||||
"due": "2025-09-17",
|
||||
"description": "Remember birthday",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Test that the status filter is passed correctly to the service call.
|
||||
# We don't assert on the response since it is fixed above.
|
||||
calls.clear()
|
||||
result = await tool.async_call(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
"todo_get_items",
|
||||
{"todo_list": "Mock Todo List Name", "status": "completed"},
|
||||
),
|
||||
llm_context,
|
||||
)
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data == {
|
||||
"entity_id": ["todo.test_list"],
|
||||
"status": ["completed"],
|
||||
}
|
||||
|
||||
# Test that the status filter is passed correctly to the service call.
|
||||
# We don't assert on the response since it is fixed above.
|
||||
calls.clear()
|
||||
result = await tool.async_call(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
"todo_get_items",
|
||||
{"todo_list": "Mock Todo List Name", "status": "all"},
|
||||
),
|
||||
llm_context,
|
||||
)
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data == {
|
||||
"entity_id": ["todo.test_list"],
|
||||
"status": ["needs_action", "completed"],
|
||||
}
|
||||
|
||||
|
||||
async def test_no_tools_exposed(hass: HomeAssistant) -> None:
|
||||
"""Test that tools are not exposed when no entities are exposed."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
Loading…
x
Reference in New Issue
Block a user