mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Simplify llm calendar tool (#137402)
* Simplify calendar tool * Clean up exposed entities
This commit is contained in:
parent
79563f3746
commit
c506c9080a
@ -35,13 +35,13 @@ class StatelessAssistAPI(llm.AssistAPI):
|
|||||||
"""Return the prompt for the exposed entities."""
|
"""Return the prompt for the exposed entities."""
|
||||||
prompt = []
|
prompt = []
|
||||||
|
|
||||||
if exposed_entities:
|
if exposed_entities and exposed_entities["entities"]:
|
||||||
prompt.append(
|
prompt.append(
|
||||||
"An overview of the areas and the devices in this smart home:"
|
"An overview of the areas and the devices in this smart home:"
|
||||||
)
|
)
|
||||||
entities = [
|
entities = [
|
||||||
{k: v for k, v in entity_info.items() if k in EXPOSED_ENTITY_FIELDS}
|
{k: v for k, v in entity_info.items() if k in EXPOSED_ENTITY_FIELDS}
|
||||||
for entity_info in exposed_entities.values()
|
for entity_info in exposed_entities["entities"].values()
|
||||||
]
|
]
|
||||||
prompt.append(yaml_util.dump(list(entities)))
|
prompt.append(yaml_util.dump(list(entities)))
|
||||||
|
|
||||||
|
@ -326,7 +326,7 @@ class AssistAPI(API):
|
|||||||
def _async_get_api_prompt(
|
def _async_get_api_prompt(
|
||||||
self, llm_context: LLMContext, exposed_entities: dict | None
|
self, llm_context: LLMContext, exposed_entities: dict | None
|
||||||
) -> str:
|
) -> str:
|
||||||
if not exposed_entities:
|
if not exposed_entities or not exposed_entities["entities"]:
|
||||||
return (
|
return (
|
||||||
"Only if the user wants to control a device, tell them to expose entities "
|
"Only if the user wants to control a device, tell them to expose entities "
|
||||||
"to their voice assistant in Home Assistant."
|
"to their voice assistant in Home Assistant."
|
||||||
@ -389,11 +389,11 @@ class AssistAPI(API):
|
|||||||
"""Return the prompt for the API for exposed entities."""
|
"""Return the prompt for the API for exposed entities."""
|
||||||
prompt = []
|
prompt = []
|
||||||
|
|
||||||
if exposed_entities:
|
if exposed_entities and exposed_entities["entities"]:
|
||||||
prompt.append(
|
prompt.append(
|
||||||
"An overview of the areas and the devices in this smart home:"
|
"An overview of the areas and the devices in this smart home:"
|
||||||
)
|
)
|
||||||
prompt.append(yaml_util.dump(list(exposed_entities.values())))
|
prompt.append(yaml_util.dump(list(exposed_entities["entities"].values())))
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
@ -425,8 +425,9 @@ class AssistAPI(API):
|
|||||||
exposed_domains: set[str] | None = None
|
exposed_domains: set[str] | None = None
|
||||||
if exposed_entities is not None:
|
if exposed_entities is not None:
|
||||||
exposed_domains = {
|
exposed_domains = {
|
||||||
split_entity_id(entity_id)[0] for entity_id in exposed_entities
|
info["domain"] for info in exposed_entities["entities"].values()
|
||||||
}
|
}
|
||||||
|
|
||||||
intent_handlers = [
|
intent_handlers = [
|
||||||
intent_handler
|
intent_handler
|
||||||
for intent_handler in intent_handlers
|
for intent_handler in intent_handlers
|
||||||
@ -438,25 +439,29 @@ class AssistAPI(API):
|
|||||||
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
|
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
|
||||||
for intent_handler in intent_handlers
|
for intent_handler in intent_handlers
|
||||||
]
|
]
|
||||||
if exposed_domains and CALENDAR_DOMAIN in exposed_domains:
|
|
||||||
tools.append(CalendarGetEventsTool())
|
|
||||||
|
|
||||||
if llm_context.assistant is not None:
|
if exposed_entities:
|
||||||
for state in self.hass.states.async_all(SCRIPT_DOMAIN):
|
if exposed_entities[CALENDAR_DOMAIN]:
|
||||||
if not async_should_expose(
|
names = []
|
||||||
self.hass, llm_context.assistant, state.entity_id
|
for info in exposed_entities[CALENDAR_DOMAIN].values():
|
||||||
):
|
names.extend(info["names"].split(", "))
|
||||||
continue
|
tools.append(CalendarGetEventsTool(names))
|
||||||
|
|
||||||
tools.append(ScriptTool(self.hass, state.entity_id))
|
tools.extend(
|
||||||
|
ScriptTool(self.hass, script_entity_id)
|
||||||
|
for script_entity_id in exposed_entities[SCRIPT_DOMAIN]
|
||||||
|
)
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
||||||
def _get_exposed_entities(
|
def _get_exposed_entities(
|
||||||
hass: HomeAssistant, assistant: str
|
hass: HomeAssistant, assistant: str
|
||||||
) -> dict[str, dict[str, Any]]:
|
) -> dict[str, dict[str, dict[str, Any]]]:
|
||||||
"""Get exposed entities."""
|
"""Get exposed entities.
|
||||||
|
|
||||||
|
Splits out calendars and scripts.
|
||||||
|
"""
|
||||||
area_registry = ar.async_get(hass)
|
area_registry = ar.async_get(hass)
|
||||||
entity_registry = er.async_get(hass)
|
entity_registry = er.async_get(hass)
|
||||||
device_registry = dr.async_get(hass)
|
device_registry = dr.async_get(hass)
|
||||||
@ -477,12 +482,13 @@ def _get_exposed_entities(
|
|||||||
}
|
}
|
||||||
|
|
||||||
entities = {}
|
entities = {}
|
||||||
|
data: dict[str, dict[str, Any]] = {
|
||||||
|
SCRIPT_DOMAIN: {},
|
||||||
|
CALENDAR_DOMAIN: {},
|
||||||
|
}
|
||||||
|
|
||||||
for state in hass.states.async_all():
|
for state in hass.states.async_all():
|
||||||
if (
|
if not async_should_expose(hass, assistant, state.entity_id):
|
||||||
not async_should_expose(hass, assistant, state.entity_id)
|
|
||||||
or state.domain == SCRIPT_DOMAIN
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
@ -529,9 +535,13 @@ def _get_exposed_entities(
|
|||||||
}:
|
}:
|
||||||
info["attributes"] = attributes
|
info["attributes"] = attributes
|
||||||
|
|
||||||
entities[state.entity_id] = info
|
if state.domain in data:
|
||||||
|
data[state.domain][state.entity_id] = info
|
||||||
|
else:
|
||||||
|
entities[state.entity_id] = info
|
||||||
|
|
||||||
return entities
|
data["entities"] = entities
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _selector_serializer(schema: Any) -> Any: # noqa: C901
|
def _selector_serializer(schema: Any) -> Any: # noqa: C901
|
||||||
@ -813,15 +823,18 @@ class CalendarGetEventsTool(Tool):
|
|||||||
name = "calendar_get_events"
|
name = "calendar_get_events"
|
||||||
description = (
|
description = (
|
||||||
"Get events from a calendar. "
|
"Get events from a calendar. "
|
||||||
"When asked when something happens, search the whole week. "
|
"When asked if something happens, search the whole week. "
|
||||||
"Results are RFC 5545 which means 'end' is exclusive."
|
"Results are RFC 5545 which means 'end' is exclusive."
|
||||||
)
|
)
|
||||||
parameters = vol.Schema(
|
|
||||||
{
|
def __init__(self, calendars: list[str]) -> None:
|
||||||
vol.Required("calendar"): cv.string,
|
"""Init the get events tool."""
|
||||||
vol.Required("range"): vol.In(["today", "week"]),
|
self.parameters = vol.Schema(
|
||||||
}
|
{
|
||||||
)
|
vol.Required("calendar"): vol.In(calendars),
|
||||||
|
vol.Required("range"): vol.In(["today", "week"]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
async def async_call(
|
async def async_call(
|
||||||
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||||
|
@ -1170,7 +1170,9 @@ async def test_selector_serializer(
|
|||||||
async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
|
async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
|
||||||
"""Test the calendar get events tool."""
|
"""Test the calendar get events tool."""
|
||||||
assert await async_setup_component(hass, "homeassistant", {})
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
hass.states.async_set("calendar.test_calendar", "on", {"friendly_name": "Test"})
|
hass.states.async_set(
|
||||||
|
"calendar.test_calendar", "on", {"friendly_name": "Mock Calendar Name"}
|
||||||
|
)
|
||||||
async_expose_entity(hass, "conversation", "calendar.test_calendar", True)
|
async_expose_entity(hass, "conversation", "calendar.test_calendar", True)
|
||||||
context = Context()
|
context = Context()
|
||||||
llm_context = llm.LLMContext(
|
llm_context = llm.LLMContext(
|
||||||
@ -1182,7 +1184,11 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
|
|||||||
device_id=None,
|
device_id=None,
|
||||||
)
|
)
|
||||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
assert [tool for tool in api.tools if tool.name == "calendar_get_events"]
|
tool = next(
|
||||||
|
(tool for tool in api.tools if tool.name == "calendar_get_events"), None
|
||||||
|
)
|
||||||
|
assert tool is not None
|
||||||
|
assert tool.parameters.schema["calendar"].container == ["Mock Calendar Name"]
|
||||||
|
|
||||||
calls = async_mock_service(
|
calls = async_mock_service(
|
||||||
hass,
|
hass,
|
||||||
@ -1212,7 +1218,10 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
|
|||||||
|
|
||||||
tool_input = llm.ToolInput(
|
tool_input = llm.ToolInput(
|
||||||
tool_name="calendar_get_events",
|
tool_name="calendar_get_events",
|
||||||
tool_args={"calendar": "calendar.test_calendar", "range": "today"},
|
tool_args={
|
||||||
|
"calendar": "Mock Calendar Name",
|
||||||
|
"range": "today",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
now = dt_util.now()
|
now = dt_util.now()
|
||||||
with patch("homeassistant.util.dt.now", return_value=now):
|
with patch("homeassistant.util.dt.now", return_value=now):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user