Simplify llm calendar tool (#137402)

* Simplify calendar tool

* Clean up exposed entities
This commit is contained in:
Paulus Schoutsen 2025-02-05 05:42:41 -05:00 committed by GitHub
parent 4d2c46959e
commit ac42c9386c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 55 additions and 33 deletions

View File

@ -35,13 +35,13 @@ class StatelessAssistAPI(llm.AssistAPI):
"""Return the prompt for the exposed entities.""" """Return the prompt for the exposed entities."""
prompt = [] prompt = []
if exposed_entities: if exposed_entities and exposed_entities["entities"]:
prompt.append( prompt.append(
"An overview of the areas and the devices in this smart home:" "An overview of the areas and the devices in this smart home:"
) )
entities = [ entities = [
{k: v for k, v in entity_info.items() if k in EXPOSED_ENTITY_FIELDS} {k: v for k, v in entity_info.items() if k in EXPOSED_ENTITY_FIELDS}
for entity_info in exposed_entities.values() for entity_info in exposed_entities["entities"].values()
] ]
prompt.append(yaml_util.dump(list(entities))) prompt.append(yaml_util.dump(list(entities)))

View File

@ -329,7 +329,7 @@ class AssistAPI(API):
def _async_get_api_prompt( def _async_get_api_prompt(
self, llm_context: LLMContext, exposed_entities: dict | None self, llm_context: LLMContext, exposed_entities: dict | None
) -> str: ) -> str:
if not exposed_entities: if not exposed_entities or not exposed_entities["entities"]:
return ( return (
"Only if the user wants to control a device, tell them to expose entities " "Only if the user wants to control a device, tell them to expose entities "
"to their voice assistant in Home Assistant." "to their voice assistant in Home Assistant."
@ -392,11 +392,11 @@ class AssistAPI(API):
"""Return the prompt for the API for exposed entities.""" """Return the prompt for the API for exposed entities."""
prompt = [] prompt = []
if exposed_entities: if exposed_entities and exposed_entities["entities"]:
prompt.append( prompt.append(
"An overview of the areas and the devices in this smart home:" "An overview of the areas and the devices in this smart home:"
) )
prompt.append(yaml_util.dump(list(exposed_entities.values()))) prompt.append(yaml_util.dump(list(exposed_entities["entities"].values())))
return prompt return prompt
@ -428,8 +428,9 @@ class AssistAPI(API):
exposed_domains: set[str] | None = None exposed_domains: set[str] | None = None
if exposed_entities is not None: if exposed_entities is not None:
exposed_domains = { exposed_domains = {
split_entity_id(entity_id)[0] for entity_id in exposed_entities info["domain"] for info in exposed_entities["entities"].values()
} }
intent_handlers = [ intent_handlers = [
intent_handler intent_handler
for intent_handler in intent_handlers for intent_handler in intent_handlers
@ -441,25 +442,29 @@ class AssistAPI(API):
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler) IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
for intent_handler in intent_handlers for intent_handler in intent_handlers
] ]
if exposed_domains and CALENDAR_DOMAIN in exposed_domains:
tools.append(CalendarGetEventsTool())
if llm_context.assistant is not None: if exposed_entities:
for state in self.hass.states.async_all(SCRIPT_DOMAIN): if exposed_entities[CALENDAR_DOMAIN]:
if not async_should_expose( names = []
self.hass, llm_context.assistant, state.entity_id for info in exposed_entities[CALENDAR_DOMAIN].values():
): names.extend(info["names"].split(", "))
continue tools.append(CalendarGetEventsTool(names))
tools.append(ScriptTool(self.hass, state.entity_id)) tools.extend(
ScriptTool(self.hass, script_entity_id)
for script_entity_id in exposed_entities[SCRIPT_DOMAIN]
)
return tools return tools
def _get_exposed_entities( def _get_exposed_entities(
hass: HomeAssistant, assistant: str hass: HomeAssistant, assistant: str
) -> dict[str, dict[str, Any]]: ) -> dict[str, dict[str, dict[str, Any]]]:
"""Get exposed entities.""" """Get exposed entities.
Splits out calendars and scripts.
"""
area_registry = ar.async_get(hass) area_registry = ar.async_get(hass)
entity_registry = er.async_get(hass) entity_registry = er.async_get(hass)
device_registry = dr.async_get(hass) device_registry = dr.async_get(hass)
@ -480,12 +485,13 @@ def _get_exposed_entities(
} }
entities = {} entities = {}
data: dict[str, dict[str, Any]] = {
SCRIPT_DOMAIN: {},
CALENDAR_DOMAIN: {},
}
for state in hass.states.async_all(): for state in hass.states.async_all():
if ( if not async_should_expose(hass, assistant, state.entity_id):
not async_should_expose(hass, assistant, state.entity_id)
or state.domain == SCRIPT_DOMAIN
):
continue continue
description: str | None = None description: str | None = None
@ -532,9 +538,13 @@ def _get_exposed_entities(
}: }:
info["attributes"] = attributes info["attributes"] = attributes
entities[state.entity_id] = info if state.domain in data:
data[state.domain][state.entity_id] = info
else:
entities[state.entity_id] = info
return entities data["entities"] = entities
return data
def _selector_serializer(schema: Any) -> Any: # noqa: C901 def _selector_serializer(schema: Any) -> Any: # noqa: C901
@ -816,15 +826,18 @@ class CalendarGetEventsTool(Tool):
name = "calendar_get_events" name = "calendar_get_events"
description = ( description = (
"Get events from a calendar. " "Get events from a calendar. "
"When asked when something happens, search the whole week. " "When asked if something happens, search the whole week. "
"Results are RFC 5545 which means 'end' is exclusive." "Results are RFC 5545 which means 'end' is exclusive."
) )
parameters = vol.Schema(
{ def __init__(self, calendars: list[str]) -> None:
vol.Required("calendar"): cv.string, """Init the get events tool."""
vol.Required("range"): vol.In(["today", "week"]), self.parameters = vol.Schema(
} {
) vol.Required("calendar"): vol.In(calendars),
vol.Required("range"): vol.In(["today", "week"]),
}
)
async def async_call( async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext

View File

@ -1170,7 +1170,9 @@ async def test_selector_serializer(
async def test_calendar_get_events_tool(hass: HomeAssistant) -> None: async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
"""Test the calendar get events tool.""" """Test the calendar get events tool."""
assert await async_setup_component(hass, "homeassistant", {}) assert await async_setup_component(hass, "homeassistant", {})
hass.states.async_set("calendar.test_calendar", "on", {"friendly_name": "Test"}) hass.states.async_set(
"calendar.test_calendar", "on", {"friendly_name": "Mock Calendar Name"}
)
async_expose_entity(hass, "conversation", "calendar.test_calendar", True) async_expose_entity(hass, "conversation", "calendar.test_calendar", True)
context = Context() context = Context()
llm_context = llm.LLMContext( llm_context = llm.LLMContext(
@ -1182,7 +1184,11 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
device_id=None, device_id=None,
) )
api = await llm.async_get_api(hass, "assist", llm_context) api = await llm.async_get_api(hass, "assist", llm_context)
assert [tool for tool in api.tools if tool.name == "calendar_get_events"] tool = next(
(tool for tool in api.tools if tool.name == "calendar_get_events"), None
)
assert tool is not None
assert tool.parameters.schema["calendar"].container == ["Mock Calendar Name"]
calls = async_mock_service( calls = async_mock_service(
hass, hass,
@ -1212,7 +1218,10 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
tool_input = llm.ToolInput( tool_input = llm.ToolInput(
tool_name="calendar_get_events", tool_name="calendar_get_events",
tool_args={"calendar": "calendar.test_calendar", "range": "today"}, tool_args={
"calendar": "Mock Calendar Name",
"range": "today",
},
) )
now = dt_util.now() now = dt_util.now()
with patch("homeassistant.util.dt.now", return_value=now): with patch("homeassistant.util.dt.now", return_value=now):