Add LLM ActionTool (#136591)

Add ActionTool
This commit is contained in:
Denis Shulyaka 2025-01-27 22:21:27 +03:00 committed by GitHub
parent 7497beefed
commit 85540cea3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 103 additions and 78 deletions

View File

@ -49,9 +49,9 @@ from . import (
) )
from .singleton import singleton from .singleton import singleton
SCRIPT_PARAMETERS_CACHE: HassKey[dict[str, tuple[str | None, vol.Schema]]] = HassKey( ACTION_PARAMETERS_CACHE: HassKey[
"llm_script_parameters_cache" dict[str, dict[str, tuple[str | None, vol.Schema]]]
) ] = HassKey("llm_action_parameters_cache")
LLM_API_ASSIST = "assist" LLM_API_ASSIST = "assist"
@ -624,29 +624,27 @@ def _selector_serializer(schema: Any) -> Any: # noqa: C901
return {"type": "string"} return {"type": "string"}
def _get_cached_script_parameters( def _get_cached_action_parameters(
hass: HomeAssistant, entity_id: str hass: HomeAssistant, domain: str, action: str
) -> tuple[str | None, vol.Schema]: ) -> tuple[str | None, vol.Schema]:
"""Get script description and schema.""" """Get action description and schema."""
entity_registry = er.async_get(hass)
description = None description = None
parameters = vol.Schema({}) parameters = vol.Schema({})
entity_entry = entity_registry.async_get(entity_id)
if entity_entry and entity_entry.unique_id: parameters_cache = hass.data.get(ACTION_PARAMETERS_CACHE)
parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)
if parameters_cache is None: if parameters_cache is None:
parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {} parameters_cache = hass.data[ACTION_PARAMETERS_CACHE] = {}
@callback @callback
def clear_cache(event: Event) -> None: def clear_cache(event: Event) -> None:
"""Clear script parameter cache on script reload or delete.""" """Clear action parameter cache on action removal."""
if ( if (
event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN event.data[ATTR_DOMAIN] in parameters_cache
and event.data[ATTR_SERVICE] in parameters_cache and event.data[ATTR_SERVICE]
in parameters_cache[event.data[ATTR_DOMAIN]]
): ):
parameters_cache.pop(event.data[ATTR_SERVICE]) parameters_cache[event.data[ATTR_DOMAIN]].pop(event.data[ATTR_SERVICE])
cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache) cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
@ -655,19 +653,17 @@ def _get_cached_script_parameters(
"""Cleanup.""" """Cleanup."""
cancel() cancel()
hass.bus.async_listen_once( hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close)
EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
)
if entity_entry.unique_id in parameters_cache: if domain in parameters_cache and action in parameters_cache[domain]:
return parameters_cache[entity_entry.unique_id] return parameters_cache[domain][action]
if service_desc := service.async_get_cached_service_description( if action_desc := service.async_get_cached_service_description(
hass, SCRIPT_DOMAIN, entity_entry.unique_id hass, domain, action
): ):
description = service_desc.get("description") description = action_desc.get("description")
schema: dict[vol.Marker, Any] = {} schema: dict[vol.Marker, Any] = {}
fields = service_desc.get("fields", {}) fields = action_desc.get("fields", {})
for field, config in fields.items(): for field, config in fields.items():
field_description = config.get("description") field_description = config.get("description")
@ -685,6 +681,11 @@ def _get_cached_script_parameters(
parameters = vol.Schema(schema) parameters = vol.Schema(schema)
if domain == SCRIPT_DOMAIN:
entity_registry = er.async_get(hass)
if (
entity_id := entity_registry.async_get_entity_id(domain, domain, action)
) and (entity_entry := entity_registry.async_get(entity_id)):
aliases: list[str] = [] aliases: list[str] = []
if entity_entry.name: if entity_entry.name:
aliases.append(entity_entry.name) aliases.append(entity_entry.name)
@ -696,32 +697,32 @@ def _get_cached_script_parameters(
else: else:
description = "Aliases: " + str(list(aliases)) description = "Aliases: " + str(list(aliases))
parameters_cache[entity_entry.unique_id] = (description, parameters) parameters_cache.setdefault(domain, {})[action] = (description, parameters)
return description, parameters return description, parameters
class ScriptTool(Tool): class ActionTool(Tool):
"""LLM Tool representing a Script.""" """LLM Tool representing an action."""
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
script_entity_id: str, domain: str,
action: str,
) -> None: ) -> None:
"""Init the class.""" """Init the class."""
self._object_id = self.name = split_entity_id(script_entity_id)[1] self._domain = domain
if self.name[0].isdigit(): self._action = action
self.name = "_" + self.name self.name = f"{domain}.{action}"
self.description, self.parameters = _get_cached_action_parameters(
self.description, self.parameters = _get_cached_script_parameters( hass, domain, action
hass, script_entity_id
) )
async def async_call( async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
) -> JsonObjectType: ) -> JsonObjectType:
"""Run the script.""" """Call the action."""
for field, validator in self.parameters.schema.items(): for field, validator in self.parameters.schema.items():
if field not in tool_input.tool_args: if field not in tool_input.tool_args:
@ -753,8 +754,8 @@ class ScriptTool(Tool):
tool_input.tool_args[field] = floor tool_input.tool_args[field] = floor
result = await hass.services.async_call( result = await hass.services.async_call(
SCRIPT_DOMAIN, self._domain,
self._object_id, self._action,
tool_input.tool_args, tool_input.tool_args,
context=llm_context.context, context=llm_context.context,
blocking=True, blocking=True,
@ -764,6 +765,30 @@ class ScriptTool(Tool):
return {"success": True, "result": result} return {"success": True, "result": result}
class ScriptTool(ActionTool):
"""LLM Tool representing a Script."""
def __init__(
self,
hass: HomeAssistant,
script_entity_id: str,
) -> None:
"""Init the class."""
script_name = split_entity_id(script_entity_id)[1]
action = script_name
entity_registry = er.async_get(hass)
entity_entry = entity_registry.async_get(script_entity_id)
if entity_entry and entity_entry.unique_id:
action = entity_entry.unique_id
super().__init__(hass, SCRIPT_DOMAIN, action)
self.name = script_name
if self.name[0].isdigit():
self.name = "_" + self.name
class CalendarGetEventsTool(Tool): class CalendarGetEventsTool(Tool):
"""LLM Tool allowing querying a calendar.""" """LLM Tool allowing querying a calendar."""

View File

@ -745,7 +745,7 @@ async def test_script_tool(
area = area_registry.async_create("Living room") area = area_registry.async_create("Living room")
floor = floor_registry.async_create("2") floor = floor_registry.async_create("2")
assert llm.SCRIPT_PARAMETERS_CACHE not in hass.data assert llm.ACTION_PARAMETERS_CACHE not in hass.data
api = await llm.async_get_api(hass, "assist", llm_context) api = await llm.async_get_api(hass, "assist", llm_context)
@ -769,7 +769,7 @@ async def test_script_tool(
} }
assert tool.parameters.schema == schema assert tool.parameters.schema == schema
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == { assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {
"test_script": ( "test_script": (
"This is a test script. Aliases: ['script name', 'script alias']", "This is a test script. Aliases: ['script name', 'script alias']",
vol.Schema(schema), vol.Schema(schema),
@ -866,7 +866,7 @@ async def test_script_tool(
): ):
await hass.services.async_call("script", "reload", blocking=True) await hass.services.async_call("script", "reload", blocking=True)
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {} assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {}
api = await llm.async_get_api(hass, "assist", llm_context) api = await llm.async_get_api(hass, "assist", llm_context)
@ -882,7 +882,7 @@ async def test_script_tool(
schema = {vol.Required("beer", description="Number of beers"): cv.string} schema = {vol.Required("beer", description="Number of beers"): cv.string}
assert tool.parameters.schema == schema assert tool.parameters.schema == schema
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == { assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {
"test_script": ( "test_script": (
"This is a new test script. Aliases: ['script name', 'script alias']", "This is a new test script. Aliases: ['script name', 'script alias']",
vol.Schema(schema), vol.Schema(schema),