Allow LLMs to get calendar events from exposed calendars (#136304)

This commit is contained in:
Paulus Schoutsen 2025-01-23 17:54:04 -05:00 committed by GitHub
parent 414fa4125e
commit 005ae3ace6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 170 additions and 4 deletions

View File

@ -5,15 +5,20 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from decimal import Decimal
from enum import Enum
from functools import cache, partial
from typing import Any
from typing import Any, cast
import slugify as unicode_slug
import voluptuous as vol
from voluptuous_openapi import UNSUPPORTED, convert
from homeassistant.components.calendar import (
DOMAIN as CALENDAR_DOMAIN,
SERVICE_GET_EVENTS,
)
from homeassistant.components.climate import INTENT_GET_TEMPERATURE
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
from homeassistant.components.homeassistant import async_should_expose
@ -28,7 +33,7 @@ from homeassistant.const import (
)
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import yaml as yaml_util
from homeassistant.util import dt as dt_util, yaml as yaml_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType
@ -415,6 +420,8 @@ class AssistAPI(API):
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
for intent_handler in intent_handlers
]
if exposed_domains and CALENDAR_DOMAIN in exposed_domains:
tools.append(CalendarGetEventsTool())
if llm_context.assistant is not None:
for state in self.hass.states.async_all(SCRIPT_DOMAIN):
@ -755,3 +762,66 @@ class ScriptTool(Tool):
)
return {"success": True, "result": result}
class CalendarGetEventsTool(Tool):
"""LLM Tool allowing querying a calendar."""
name = "calendar_get_events"
description = (
"Get events from a calendar. "
"When asked when something happens, search the whole week. "
"Results are RFC 5545 which means 'end' is exclusive."
)
parameters = vol.Schema(
{
vol.Required("calendar"): cv.string,
vol.Required("range"): vol.In(["today", "week"]),
}
)
async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
) -> JsonObjectType:
"""Query a calendar."""
data = self.parameters(tool_input.tool_args)
result = intent.async_match_targets(
hass,
intent.MatchTargetsConstraints(
name=data["calendar"],
domains=[CALENDAR_DOMAIN],
assistant=llm_context.assistant,
),
)
if not result.is_match:
return {"success": False, "error": "Calendar not found"}
entity_id = result.states[0].entity_id
if data["range"] == "today":
start = dt_util.now()
end = dt_util.start_of_local_day() + timedelta(days=1)
elif data["range"] == "week":
start = dt_util.now()
end = dt_util.start_of_local_day() + timedelta(days=7)
service_data = {
"entity_id": entity_id,
"start_date_time": start.isoformat(),
"end_date_time": end.isoformat(),
}
service_result = await hass.services.async_call(
CALENDAR_DOMAIN,
SERVICE_GET_EVENTS,
service_data,
context=llm_context.context,
blocking=True,
return_response=True,
)
events = [
event if "T" in event["start"] else {**event, "all_day": True}
for event in cast(dict, service_result)[entity_id]["events"]
]
return {"success": True, "result": events}

View File

@ -1,15 +1,17 @@
"""Tests for the llm helpers."""
from datetime import timedelta
from decimal import Decimal
from unittest.mock import patch
import pytest
import voluptuous as vol
from homeassistant.components import calendar
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.components.intent import async_register_timer_handler
from homeassistant.components.script.config import ScriptConfig
from homeassistant.core import Context, HomeAssistant, State
from homeassistant.core import Context, HomeAssistant, State, SupportsResponse
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
area_registry as ar,
@ -22,8 +24,9 @@ from homeassistant.helpers import (
selector,
)
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
from tests.common import MockConfigEntry
from tests.common import MockConfigEntry, async_mock_service
@pytest.fixture
@ -1162,3 +1165,96 @@ async def test_selector_serializer(
assert selector_serializer(selector.FileSelector({"accept": ".txt"})) == {
"type": "string"
}
async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
"""Test the calendar get events tool."""
assert await async_setup_component(hass, "homeassistant", {})
hass.states.async_set("calendar.test_calendar", "on", {"friendly_name": "Test"})
async_expose_entity(hass, "conversation", "calendar.test_calendar", True)
context = Context()
llm_context = llm.LLMContext(
platform="test_platform",
context=context,
user_prompt="test_text",
language="*",
assistant="conversation",
device_id=None,
)
api = await llm.async_get_api(hass, "assist", llm_context)
assert [tool for tool in api.tools if tool.name == "calendar_get_events"]
calls = async_mock_service(
hass,
domain=calendar.DOMAIN,
service=calendar.SERVICE_GET_EVENTS,
schema=calendar.SERVICE_GET_EVENTS_SCHEMA,
response={
"calendar.test_calendar": {
"events": [
{
"start": "2025-09-17",
"end": "2025-09-18",
"summary": "Home Assistant 12th birthday",
"description": "",
},
{
"start": "2025-09-17T14:00:00-05:00",
"end": "2025-09-18T15:00:00-05:00",
"summary": "Champagne",
"description": "",
},
]
}
},
supports_response=SupportsResponse.ONLY,
)
tool_input = llm.ToolInput(
tool_name="calendar_get_events",
tool_args={"calendar": "calendar.test_calendar", "range": "today"},
)
now = dt_util.now()
with patch("homeassistant.util.dt.now", return_value=now):
response = await api.async_call_tool(tool_input)
assert len(calls) == 1
call = calls[0]
assert call.domain == calendar.DOMAIN
assert call.service == calendar.SERVICE_GET_EVENTS
assert call.data == {
"entity_id": ["calendar.test_calendar"],
"start_date_time": now,
"end_date_time": dt_util.start_of_local_day() + timedelta(days=1),
}
assert response == {
"success": True,
"result": [
{
"start": "2025-09-17",
"end": "2025-09-18",
"summary": "Home Assistant 12th birthday",
"description": "",
"all_day": True,
},
{
"start": "2025-09-17T14:00:00-05:00",
"end": "2025-09-18T15:00:00-05:00",
"summary": "Champagne",
"description": "",
},
],
}
tool_input.tool_args["range"] = "week"
with patch("homeassistant.util.dt.now", return_value=now):
response = await api.async_call_tool(tool_input)
assert len(calls) == 2
call = calls[1]
assert call.data == {
"entity_id": ["calendar.test_calendar"],
"start_date_time": now,
"end_date_time": dt_util.start_of_local_day() + timedelta(days=7),
}