mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 09:17:53 +00:00
Allow LLMs to get calendar events from exposed calendars (#136304)
This commit is contained in:
parent
414fa4125e
commit
005ae3ace6
@ -5,15 +5,20 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from functools import cache, partial
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import slugify as unicode_slug
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import UNSUPPORTED, convert
|
||||
|
||||
from homeassistant.components.calendar import (
|
||||
DOMAIN as CALENDAR_DOMAIN,
|
||||
SERVICE_GET_EVENTS,
|
||||
)
|
||||
from homeassistant.components.climate import INTENT_GET_TEMPERATURE
|
||||
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
||||
from homeassistant.components.homeassistant import async_should_expose
|
||||
@ -28,7 +33,7 @@ from homeassistant.const import (
|
||||
)
|
||||
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util import yaml as yaml_util
|
||||
from homeassistant.util import dt as dt_util, yaml as yaml_util
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
@ -415,6 +420,8 @@ class AssistAPI(API):
|
||||
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
|
||||
for intent_handler in intent_handlers
|
||||
]
|
||||
if exposed_domains and CALENDAR_DOMAIN in exposed_domains:
|
||||
tools.append(CalendarGetEventsTool())
|
||||
|
||||
if llm_context.assistant is not None:
|
||||
for state in self.hass.states.async_all(SCRIPT_DOMAIN):
|
||||
@ -755,3 +762,66 @@ class ScriptTool(Tool):
|
||||
)
|
||||
|
||||
return {"success": True, "result": result}
|
||||
|
||||
|
||||
class CalendarGetEventsTool(Tool):
|
||||
"""LLM Tool allowing querying a calendar."""
|
||||
|
||||
name = "calendar_get_events"
|
||||
description = (
|
||||
"Get events from a calendar. "
|
||||
"When asked when something happens, search the whole week. "
|
||||
"Results are RFC 5545 which means 'end' is exclusive."
|
||||
)
|
||||
parameters = vol.Schema(
|
||||
{
|
||||
vol.Required("calendar"): cv.string,
|
||||
vol.Required("range"): vol.In(["today", "week"]),
|
||||
}
|
||||
)
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||
) -> JsonObjectType:
|
||||
"""Query a calendar."""
|
||||
data = self.parameters(tool_input.tool_args)
|
||||
result = intent.async_match_targets(
|
||||
hass,
|
||||
intent.MatchTargetsConstraints(
|
||||
name=data["calendar"],
|
||||
domains=[CALENDAR_DOMAIN],
|
||||
assistant=llm_context.assistant,
|
||||
),
|
||||
)
|
||||
if not result.is_match:
|
||||
return {"success": False, "error": "Calendar not found"}
|
||||
|
||||
entity_id = result.states[0].entity_id
|
||||
if data["range"] == "today":
|
||||
start = dt_util.now()
|
||||
end = dt_util.start_of_local_day() + timedelta(days=1)
|
||||
elif data["range"] == "week":
|
||||
start = dt_util.now()
|
||||
end = dt_util.start_of_local_day() + timedelta(days=7)
|
||||
|
||||
service_data = {
|
||||
"entity_id": entity_id,
|
||||
"start_date_time": start.isoformat(),
|
||||
"end_date_time": end.isoformat(),
|
||||
}
|
||||
|
||||
service_result = await hass.services.async_call(
|
||||
CALENDAR_DOMAIN,
|
||||
SERVICE_GET_EVENTS,
|
||||
service_data,
|
||||
context=llm_context.context,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
events = [
|
||||
event if "T" in event["start"] else {**event, "all_day": True}
|
||||
for event in cast(dict, service_result)[entity_id]["events"]
|
||||
]
|
||||
|
||||
return {"success": True, "result": events}
|
||||
|
@ -1,15 +1,17 @@
|
||||
"""Tests for the llm helpers."""
|
||||
|
||||
from datetime import timedelta
|
||||
from decimal import Decimal
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import calendar
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||
from homeassistant.components.intent import async_register_timer_handler
|
||||
from homeassistant.components.script.config import ScriptConfig
|
||||
from homeassistant.core import Context, HomeAssistant, State
|
||||
from homeassistant.core import Context, HomeAssistant, State, SupportsResponse
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
@ -22,8 +24,9 @@ from homeassistant.helpers import (
|
||||
selector,
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.common import MockConfigEntry, async_mock_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -1162,3 +1165,96 @@ async def test_selector_serializer(
|
||||
assert selector_serializer(selector.FileSelector({"accept": ".txt"})) == {
|
||||
"type": "string"
|
||||
}
|
||||
|
||||
|
||||
async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
|
||||
"""Test the calendar get events tool."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
hass.states.async_set("calendar.test_calendar", "on", {"friendly_name": "Test"})
|
||||
async_expose_entity(hass, "conversation", "calendar.test_calendar", True)
|
||||
context = Context()
|
||||
llm_context = llm.LLMContext(
|
||||
platform="test_platform",
|
||||
context=context,
|
||||
user_prompt="test_text",
|
||||
language="*",
|
||||
assistant="conversation",
|
||||
device_id=None,
|
||||
)
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
assert [tool for tool in api.tools if tool.name == "calendar_get_events"]
|
||||
|
||||
calls = async_mock_service(
|
||||
hass,
|
||||
domain=calendar.DOMAIN,
|
||||
service=calendar.SERVICE_GET_EVENTS,
|
||||
schema=calendar.SERVICE_GET_EVENTS_SCHEMA,
|
||||
response={
|
||||
"calendar.test_calendar": {
|
||||
"events": [
|
||||
{
|
||||
"start": "2025-09-17",
|
||||
"end": "2025-09-18",
|
||||
"summary": "Home Assistant 12th birthday",
|
||||
"description": "",
|
||||
},
|
||||
{
|
||||
"start": "2025-09-17T14:00:00-05:00",
|
||||
"end": "2025-09-18T15:00:00-05:00",
|
||||
"summary": "Champagne",
|
||||
"description": "",
|
||||
},
|
||||
]
|
||||
}
|
||||
},
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name="calendar_get_events",
|
||||
tool_args={"calendar": "calendar.test_calendar", "range": "today"},
|
||||
)
|
||||
now = dt_util.now()
|
||||
with patch("homeassistant.util.dt.now", return_value=now):
|
||||
response = await api.async_call_tool(tool_input)
|
||||
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
assert call.domain == calendar.DOMAIN
|
||||
assert call.service == calendar.SERVICE_GET_EVENTS
|
||||
assert call.data == {
|
||||
"entity_id": ["calendar.test_calendar"],
|
||||
"start_date_time": now,
|
||||
"end_date_time": dt_util.start_of_local_day() + timedelta(days=1),
|
||||
}
|
||||
|
||||
assert response == {
|
||||
"success": True,
|
||||
"result": [
|
||||
{
|
||||
"start": "2025-09-17",
|
||||
"end": "2025-09-18",
|
||||
"summary": "Home Assistant 12th birthday",
|
||||
"description": "",
|
||||
"all_day": True,
|
||||
},
|
||||
{
|
||||
"start": "2025-09-17T14:00:00-05:00",
|
||||
"end": "2025-09-18T15:00:00-05:00",
|
||||
"summary": "Champagne",
|
||||
"description": "",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
tool_input.tool_args["range"] = "week"
|
||||
with patch("homeassistant.util.dt.now", return_value=now):
|
||||
response = await api.async_call_tool(tool_input)
|
||||
|
||||
assert len(calls) == 2
|
||||
call = calls[1]
|
||||
assert call.data == {
|
||||
"entity_id": ["calendar.test_calendar"],
|
||||
"start_date_time": now,
|
||||
"end_date_time": dt_util.start_of_local_day() + timedelta(days=7),
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user