Simplify llm calendar tool (#137402)

* Simplify calendar tool

* Clean up exposed entities
This commit is contained in:
Paulus Schoutsen 2025-02-05 05:42:41 -05:00 committed by GitHub
parent 4d2c46959e
commit ac42c9386c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 55 additions and 33 deletions

View File

@ -35,13 +35,13 @@ class StatelessAssistAPI(llm.AssistAPI):
"""Return the prompt for the exposed entities."""
prompt = []
if exposed_entities:
if exposed_entities and exposed_entities["entities"]:
prompt.append(
"An overview of the areas and the devices in this smart home:"
)
entities = [
{k: v for k, v in entity_info.items() if k in EXPOSED_ENTITY_FIELDS}
for entity_info in exposed_entities.values()
for entity_info in exposed_entities["entities"].values()
]
prompt.append(yaml_util.dump(list(entities)))

View File

@ -329,7 +329,7 @@ class AssistAPI(API):
def _async_get_api_prompt(
self, llm_context: LLMContext, exposed_entities: dict | None
) -> str:
if not exposed_entities:
if not exposed_entities or not exposed_entities["entities"]:
return (
"Only if the user wants to control a device, tell them to expose entities "
"to their voice assistant in Home Assistant."
@ -392,11 +392,11 @@ class AssistAPI(API):
"""Return the prompt for the API for exposed entities."""
prompt = []
if exposed_entities:
if exposed_entities and exposed_entities["entities"]:
prompt.append(
"An overview of the areas and the devices in this smart home:"
)
prompt.append(yaml_util.dump(list(exposed_entities.values())))
prompt.append(yaml_util.dump(list(exposed_entities["entities"].values())))
return prompt
@ -428,8 +428,9 @@ class AssistAPI(API):
exposed_domains: set[str] | None = None
if exposed_entities is not None:
exposed_domains = {
split_entity_id(entity_id)[0] for entity_id in exposed_entities
info["domain"] for info in exposed_entities["entities"].values()
}
intent_handlers = [
intent_handler
for intent_handler in intent_handlers
@ -441,25 +442,29 @@ class AssistAPI(API):
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
for intent_handler in intent_handlers
]
if exposed_domains and CALENDAR_DOMAIN in exposed_domains:
tools.append(CalendarGetEventsTool())
if llm_context.assistant is not None:
for state in self.hass.states.async_all(SCRIPT_DOMAIN):
if not async_should_expose(
self.hass, llm_context.assistant, state.entity_id
):
continue
if exposed_entities:
if exposed_entities[CALENDAR_DOMAIN]:
names = []
for info in exposed_entities[CALENDAR_DOMAIN].values():
names.extend(info["names"].split(", "))
tools.append(CalendarGetEventsTool(names))
tools.append(ScriptTool(self.hass, state.entity_id))
tools.extend(
ScriptTool(self.hass, script_entity_id)
for script_entity_id in exposed_entities[SCRIPT_DOMAIN]
)
return tools
def _get_exposed_entities(
hass: HomeAssistant, assistant: str
) -> dict[str, dict[str, Any]]:
"""Get exposed entities."""
) -> dict[str, dict[str, dict[str, Any]]]:
"""Get exposed entities.
Splits out calendars and scripts.
"""
area_registry = ar.async_get(hass)
entity_registry = er.async_get(hass)
device_registry = dr.async_get(hass)
@ -480,12 +485,13 @@ def _get_exposed_entities(
}
entities = {}
data: dict[str, dict[str, Any]] = {
SCRIPT_DOMAIN: {},
CALENDAR_DOMAIN: {},
}
for state in hass.states.async_all():
if (
not async_should_expose(hass, assistant, state.entity_id)
or state.domain == SCRIPT_DOMAIN
):
if not async_should_expose(hass, assistant, state.entity_id):
continue
description: str | None = None
@ -532,9 +538,13 @@ def _get_exposed_entities(
}:
info["attributes"] = attributes
entities[state.entity_id] = info
if state.domain in data:
data[state.domain][state.entity_id] = info
else:
entities[state.entity_id] = info
return entities
data["entities"] = entities
return data
def _selector_serializer(schema: Any) -> Any: # noqa: C901
@ -816,15 +826,18 @@ class CalendarGetEventsTool(Tool):
name = "calendar_get_events"
description = (
"Get events from a calendar. "
"When asked when something happens, search the whole week. "
"When asked if something happens, search the whole week. "
"Results are RFC 5545 which means 'end' is exclusive."
)
parameters = vol.Schema(
{
vol.Required("calendar"): cv.string,
vol.Required("range"): vol.In(["today", "week"]),
}
)
def __init__(self, calendars: list[str]) -> None:
"""Init the get events tool."""
self.parameters = vol.Schema(
{
vol.Required("calendar"): vol.In(calendars),
vol.Required("range"): vol.In(["today", "week"]),
}
)
async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext

View File

@ -1170,7 +1170,9 @@ async def test_selector_serializer(
async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
"""Test the calendar get events tool."""
assert await async_setup_component(hass, "homeassistant", {})
hass.states.async_set("calendar.test_calendar", "on", {"friendly_name": "Test"})
hass.states.async_set(
"calendar.test_calendar", "on", {"friendly_name": "Mock Calendar Name"}
)
async_expose_entity(hass, "conversation", "calendar.test_calendar", True)
context = Context()
llm_context = llm.LLMContext(
@ -1182,7 +1184,11 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
device_id=None,
)
api = await llm.async_get_api(hass, "assist", llm_context)
assert [tool for tool in api.tools if tool.name == "calendar_get_events"]
tool = next(
(tool for tool in api.tools if tool.name == "calendar_get_events"), None
)
assert tool is not None
assert tool.parameters.schema["calendar"].container == ["Mock Calendar Name"]
calls = async_mock_service(
hass,
@ -1212,7 +1218,10 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
tool_input = llm.ToolInput(
tool_name="calendar_get_events",
tool_args={"calendar": "calendar.test_calendar", "range": "today"},
tool_args={
"calendar": "Mock Calendar Name",
"range": "today",
},
)
now = dt_util.now()
with patch("homeassistant.util.dt.now", return_value=now):