mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Simplify llm calendar tool (#137402)
* Simplify calendar tool * Clean up exposed entities
This commit is contained in:
parent
4d2c46959e
commit
ac42c9386c
@ -35,13 +35,13 @@ class StatelessAssistAPI(llm.AssistAPI):
|
||||
"""Return the prompt for the exposed entities."""
|
||||
prompt = []
|
||||
|
||||
if exposed_entities:
|
||||
if exposed_entities and exposed_entities["entities"]:
|
||||
prompt.append(
|
||||
"An overview of the areas and the devices in this smart home:"
|
||||
)
|
||||
entities = [
|
||||
{k: v for k, v in entity_info.items() if k in EXPOSED_ENTITY_FIELDS}
|
||||
for entity_info in exposed_entities.values()
|
||||
for entity_info in exposed_entities["entities"].values()
|
||||
]
|
||||
prompt.append(yaml_util.dump(list(entities)))
|
||||
|
||||
|
@ -329,7 +329,7 @@ class AssistAPI(API):
|
||||
def _async_get_api_prompt(
|
||||
self, llm_context: LLMContext, exposed_entities: dict | None
|
||||
) -> str:
|
||||
if not exposed_entities:
|
||||
if not exposed_entities or not exposed_entities["entities"]:
|
||||
return (
|
||||
"Only if the user wants to control a device, tell them to expose entities "
|
||||
"to their voice assistant in Home Assistant."
|
||||
@ -392,11 +392,11 @@ class AssistAPI(API):
|
||||
"""Return the prompt for the API for exposed entities."""
|
||||
prompt = []
|
||||
|
||||
if exposed_entities:
|
||||
if exposed_entities and exposed_entities["entities"]:
|
||||
prompt.append(
|
||||
"An overview of the areas and the devices in this smart home:"
|
||||
)
|
||||
prompt.append(yaml_util.dump(list(exposed_entities.values())))
|
||||
prompt.append(yaml_util.dump(list(exposed_entities["entities"].values())))
|
||||
|
||||
return prompt
|
||||
|
||||
@ -428,8 +428,9 @@ class AssistAPI(API):
|
||||
exposed_domains: set[str] | None = None
|
||||
if exposed_entities is not None:
|
||||
exposed_domains = {
|
||||
split_entity_id(entity_id)[0] for entity_id in exposed_entities
|
||||
info["domain"] for info in exposed_entities["entities"].values()
|
||||
}
|
||||
|
||||
intent_handlers = [
|
||||
intent_handler
|
||||
for intent_handler in intent_handlers
|
||||
@ -441,25 +442,29 @@ class AssistAPI(API):
|
||||
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
|
||||
for intent_handler in intent_handlers
|
||||
]
|
||||
if exposed_domains and CALENDAR_DOMAIN in exposed_domains:
|
||||
tools.append(CalendarGetEventsTool())
|
||||
|
||||
if llm_context.assistant is not None:
|
||||
for state in self.hass.states.async_all(SCRIPT_DOMAIN):
|
||||
if not async_should_expose(
|
||||
self.hass, llm_context.assistant, state.entity_id
|
||||
):
|
||||
continue
|
||||
if exposed_entities:
|
||||
if exposed_entities[CALENDAR_DOMAIN]:
|
||||
names = []
|
||||
for info in exposed_entities[CALENDAR_DOMAIN].values():
|
||||
names.extend(info["names"].split(", "))
|
||||
tools.append(CalendarGetEventsTool(names))
|
||||
|
||||
tools.append(ScriptTool(self.hass, state.entity_id))
|
||||
tools.extend(
|
||||
ScriptTool(self.hass, script_entity_id)
|
||||
for script_entity_id in exposed_entities[SCRIPT_DOMAIN]
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
def _get_exposed_entities(
|
||||
hass: HomeAssistant, assistant: str
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Get exposed entities."""
|
||||
) -> dict[str, dict[str, dict[str, Any]]]:
|
||||
"""Get exposed entities.
|
||||
|
||||
Splits out calendars and scripts.
|
||||
"""
|
||||
area_registry = ar.async_get(hass)
|
||||
entity_registry = er.async_get(hass)
|
||||
device_registry = dr.async_get(hass)
|
||||
@ -480,12 +485,13 @@ def _get_exposed_entities(
|
||||
}
|
||||
|
||||
entities = {}
|
||||
data: dict[str, dict[str, Any]] = {
|
||||
SCRIPT_DOMAIN: {},
|
||||
CALENDAR_DOMAIN: {},
|
||||
}
|
||||
|
||||
for state in hass.states.async_all():
|
||||
if (
|
||||
not async_should_expose(hass, assistant, state.entity_id)
|
||||
or state.domain == SCRIPT_DOMAIN
|
||||
):
|
||||
if not async_should_expose(hass, assistant, state.entity_id):
|
||||
continue
|
||||
|
||||
description: str | None = None
|
||||
@ -532,9 +538,13 @@ def _get_exposed_entities(
|
||||
}:
|
||||
info["attributes"] = attributes
|
||||
|
||||
entities[state.entity_id] = info
|
||||
if state.domain in data:
|
||||
data[state.domain][state.entity_id] = info
|
||||
else:
|
||||
entities[state.entity_id] = info
|
||||
|
||||
return entities
|
||||
data["entities"] = entities
|
||||
return data
|
||||
|
||||
|
||||
def _selector_serializer(schema: Any) -> Any: # noqa: C901
|
||||
@ -816,15 +826,18 @@ class CalendarGetEventsTool(Tool):
|
||||
name = "calendar_get_events"
|
||||
description = (
|
||||
"Get events from a calendar. "
|
||||
"When asked when something happens, search the whole week. "
|
||||
"When asked if something happens, search the whole week. "
|
||||
"Results are RFC 5545 which means 'end' is exclusive."
|
||||
)
|
||||
parameters = vol.Schema(
|
||||
{
|
||||
vol.Required("calendar"): cv.string,
|
||||
vol.Required("range"): vol.In(["today", "week"]),
|
||||
}
|
||||
)
|
||||
|
||||
def __init__(self, calendars: list[str]) -> None:
|
||||
"""Init the get events tool."""
|
||||
self.parameters = vol.Schema(
|
||||
{
|
||||
vol.Required("calendar"): vol.In(calendars),
|
||||
vol.Required("range"): vol.In(["today", "week"]),
|
||||
}
|
||||
)
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||
|
@ -1170,7 +1170,9 @@ async def test_selector_serializer(
|
||||
async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
|
||||
"""Test the calendar get events tool."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
hass.states.async_set("calendar.test_calendar", "on", {"friendly_name": "Test"})
|
||||
hass.states.async_set(
|
||||
"calendar.test_calendar", "on", {"friendly_name": "Mock Calendar Name"}
|
||||
)
|
||||
async_expose_entity(hass, "conversation", "calendar.test_calendar", True)
|
||||
context = Context()
|
||||
llm_context = llm.LLMContext(
|
||||
@ -1182,7 +1184,11 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
|
||||
device_id=None,
|
||||
)
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
assert [tool for tool in api.tools if tool.name == "calendar_get_events"]
|
||||
tool = next(
|
||||
(tool for tool in api.tools if tool.name == "calendar_get_events"), None
|
||||
)
|
||||
assert tool is not None
|
||||
assert tool.parameters.schema["calendar"].container == ["Mock Calendar Name"]
|
||||
|
||||
calls = async_mock_service(
|
||||
hass,
|
||||
@ -1212,7 +1218,10 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
|
||||
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name="calendar_get_events",
|
||||
tool_args={"calendar": "calendar.test_calendar", "range": "today"},
|
||||
tool_args={
|
||||
"calendar": "Mock Calendar Name",
|
||||
"range": "today",
|
||||
},
|
||||
)
|
||||
now = dt_util.now()
|
||||
with patch("homeassistant.util.dt.now", return_value=now):
|
||||
|
Loading…
x
Reference in New Issue
Block a user