mirror of
https://github.com/home-assistant/core.git
synced 2025-07-08 22:07:10 +00:00
parent
7497beefed
commit
85540cea3f
@ -49,9 +49,9 @@ from . import (
|
|||||||
)
|
)
|
||||||
from .singleton import singleton
|
from .singleton import singleton
|
||||||
|
|
||||||
SCRIPT_PARAMETERS_CACHE: HassKey[dict[str, tuple[str | None, vol.Schema]]] = HassKey(
|
ACTION_PARAMETERS_CACHE: HassKey[
|
||||||
"llm_script_parameters_cache"
|
dict[str, dict[str, tuple[str | None, vol.Schema]]]
|
||||||
)
|
] = HassKey("llm_action_parameters_cache")
|
||||||
|
|
||||||
|
|
||||||
LLM_API_ASSIST = "assist"
|
LLM_API_ASSIST = "assist"
|
||||||
@ -624,29 +624,27 @@ def _selector_serializer(schema: Any) -> Any: # noqa: C901
|
|||||||
return {"type": "string"}
|
return {"type": "string"}
|
||||||
|
|
||||||
|
|
||||||
def _get_cached_script_parameters(
|
def _get_cached_action_parameters(
|
||||||
hass: HomeAssistant, entity_id: str
|
hass: HomeAssistant, domain: str, action: str
|
||||||
) -> tuple[str | None, vol.Schema]:
|
) -> tuple[str | None, vol.Schema]:
|
||||||
"""Get script description and schema."""
|
"""Get action description and schema."""
|
||||||
entity_registry = er.async_get(hass)
|
|
||||||
|
|
||||||
description = None
|
description = None
|
||||||
parameters = vol.Schema({})
|
parameters = vol.Schema({})
|
||||||
entity_entry = entity_registry.async_get(entity_id)
|
|
||||||
if entity_entry and entity_entry.unique_id:
|
parameters_cache = hass.data.get(ACTION_PARAMETERS_CACHE)
|
||||||
parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)
|
|
||||||
|
|
||||||
if parameters_cache is None:
|
if parameters_cache is None:
|
||||||
parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {}
|
parameters_cache = hass.data[ACTION_PARAMETERS_CACHE] = {}
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def clear_cache(event: Event) -> None:
|
def clear_cache(event: Event) -> None:
|
||||||
"""Clear script parameter cache on script reload or delete."""
|
"""Clear action parameter cache on action removal."""
|
||||||
if (
|
if (
|
||||||
event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN
|
event.data[ATTR_DOMAIN] in parameters_cache
|
||||||
and event.data[ATTR_SERVICE] in parameters_cache
|
and event.data[ATTR_SERVICE]
|
||||||
|
in parameters_cache[event.data[ATTR_DOMAIN]]
|
||||||
):
|
):
|
||||||
parameters_cache.pop(event.data[ATTR_SERVICE])
|
parameters_cache[event.data[ATTR_DOMAIN]].pop(event.data[ATTR_SERVICE])
|
||||||
|
|
||||||
cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
|
cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
|
||||||
|
|
||||||
@ -655,19 +653,17 @@ def _get_cached_script_parameters(
|
|||||||
"""Cleanup."""
|
"""Cleanup."""
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
hass.bus.async_listen_once(
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close)
|
||||||
EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
|
|
||||||
)
|
|
||||||
|
|
||||||
if entity_entry.unique_id in parameters_cache:
|
if domain in parameters_cache and action in parameters_cache[domain]:
|
||||||
return parameters_cache[entity_entry.unique_id]
|
return parameters_cache[domain][action]
|
||||||
|
|
||||||
if service_desc := service.async_get_cached_service_description(
|
if action_desc := service.async_get_cached_service_description(
|
||||||
hass, SCRIPT_DOMAIN, entity_entry.unique_id
|
hass, domain, action
|
||||||
):
|
):
|
||||||
description = service_desc.get("description")
|
description = action_desc.get("description")
|
||||||
schema: dict[vol.Marker, Any] = {}
|
schema: dict[vol.Marker, Any] = {}
|
||||||
fields = service_desc.get("fields", {})
|
fields = action_desc.get("fields", {})
|
||||||
|
|
||||||
for field, config in fields.items():
|
for field, config in fields.items():
|
||||||
field_description = config.get("description")
|
field_description = config.get("description")
|
||||||
@ -685,6 +681,11 @@ def _get_cached_script_parameters(
|
|||||||
|
|
||||||
parameters = vol.Schema(schema)
|
parameters = vol.Schema(schema)
|
||||||
|
|
||||||
|
if domain == SCRIPT_DOMAIN:
|
||||||
|
entity_registry = er.async_get(hass)
|
||||||
|
if (
|
||||||
|
entity_id := entity_registry.async_get_entity_id(domain, domain, action)
|
||||||
|
) and (entity_entry := entity_registry.async_get(entity_id)):
|
||||||
aliases: list[str] = []
|
aliases: list[str] = []
|
||||||
if entity_entry.name:
|
if entity_entry.name:
|
||||||
aliases.append(entity_entry.name)
|
aliases.append(entity_entry.name)
|
||||||
@ -696,32 +697,32 @@ def _get_cached_script_parameters(
|
|||||||
else:
|
else:
|
||||||
description = "Aliases: " + str(list(aliases))
|
description = "Aliases: " + str(list(aliases))
|
||||||
|
|
||||||
parameters_cache[entity_entry.unique_id] = (description, parameters)
|
parameters_cache.setdefault(domain, {})[action] = (description, parameters)
|
||||||
|
|
||||||
return description, parameters
|
return description, parameters
|
||||||
|
|
||||||
|
|
||||||
class ScriptTool(Tool):
|
class ActionTool(Tool):
|
||||||
"""LLM Tool representing a Script."""
|
"""LLM Tool representing an action."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
script_entity_id: str,
|
domain: str,
|
||||||
|
action: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Init the class."""
|
"""Init the class."""
|
||||||
self._object_id = self.name = split_entity_id(script_entity_id)[1]
|
self._domain = domain
|
||||||
if self.name[0].isdigit():
|
self._action = action
|
||||||
self.name = "_" + self.name
|
self.name = f"{domain}.{action}"
|
||||||
|
self.description, self.parameters = _get_cached_action_parameters(
|
||||||
self.description, self.parameters = _get_cached_script_parameters(
|
hass, domain, action
|
||||||
hass, script_entity_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_call(
|
async def async_call(
|
||||||
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||||
) -> JsonObjectType:
|
) -> JsonObjectType:
|
||||||
"""Run the script."""
|
"""Call the action."""
|
||||||
|
|
||||||
for field, validator in self.parameters.schema.items():
|
for field, validator in self.parameters.schema.items():
|
||||||
if field not in tool_input.tool_args:
|
if field not in tool_input.tool_args:
|
||||||
@ -753,8 +754,8 @@ class ScriptTool(Tool):
|
|||||||
tool_input.tool_args[field] = floor
|
tool_input.tool_args[field] = floor
|
||||||
|
|
||||||
result = await hass.services.async_call(
|
result = await hass.services.async_call(
|
||||||
SCRIPT_DOMAIN,
|
self._domain,
|
||||||
self._object_id,
|
self._action,
|
||||||
tool_input.tool_args,
|
tool_input.tool_args,
|
||||||
context=llm_context.context,
|
context=llm_context.context,
|
||||||
blocking=True,
|
blocking=True,
|
||||||
@ -764,6 +765,30 @@ class ScriptTool(Tool):
|
|||||||
return {"success": True, "result": result}
|
return {"success": True, "result": result}
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptTool(ActionTool):
|
||||||
|
"""LLM Tool representing a Script."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
script_entity_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Init the class."""
|
||||||
|
script_name = split_entity_id(script_entity_id)[1]
|
||||||
|
|
||||||
|
action = script_name
|
||||||
|
entity_registry = er.async_get(hass)
|
||||||
|
entity_entry = entity_registry.async_get(script_entity_id)
|
||||||
|
if entity_entry and entity_entry.unique_id:
|
||||||
|
action = entity_entry.unique_id
|
||||||
|
|
||||||
|
super().__init__(hass, SCRIPT_DOMAIN, action)
|
||||||
|
|
||||||
|
self.name = script_name
|
||||||
|
if self.name[0].isdigit():
|
||||||
|
self.name = "_" + self.name
|
||||||
|
|
||||||
|
|
||||||
class CalendarGetEventsTool(Tool):
|
class CalendarGetEventsTool(Tool):
|
||||||
"""LLM Tool allowing querying a calendar."""
|
"""LLM Tool allowing querying a calendar."""
|
||||||
|
|
||||||
|
@ -745,7 +745,7 @@ async def test_script_tool(
|
|||||||
area = area_registry.async_create("Living room")
|
area = area_registry.async_create("Living room")
|
||||||
floor = floor_registry.async_create("2")
|
floor = floor_registry.async_create("2")
|
||||||
|
|
||||||
assert llm.SCRIPT_PARAMETERS_CACHE not in hass.data
|
assert llm.ACTION_PARAMETERS_CACHE not in hass.data
|
||||||
|
|
||||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
|
|
||||||
@ -769,7 +769,7 @@ async def test_script_tool(
|
|||||||
}
|
}
|
||||||
assert tool.parameters.schema == schema
|
assert tool.parameters.schema == schema
|
||||||
|
|
||||||
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
|
assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {
|
||||||
"test_script": (
|
"test_script": (
|
||||||
"This is a test script. Aliases: ['script name', 'script alias']",
|
"This is a test script. Aliases: ['script name', 'script alias']",
|
||||||
vol.Schema(schema),
|
vol.Schema(schema),
|
||||||
@ -866,7 +866,7 @@ async def test_script_tool(
|
|||||||
):
|
):
|
||||||
await hass.services.async_call("script", "reload", blocking=True)
|
await hass.services.async_call("script", "reload", blocking=True)
|
||||||
|
|
||||||
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {}
|
assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {}
|
||||||
|
|
||||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
|
|
||||||
@ -882,7 +882,7 @@ async def test_script_tool(
|
|||||||
schema = {vol.Required("beer", description="Number of beers"): cv.string}
|
schema = {vol.Required("beer", description="Number of beers"): cv.string}
|
||||||
assert tool.parameters.schema == schema
|
assert tool.parameters.schema == schema
|
||||||
|
|
||||||
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
|
assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {
|
||||||
"test_script": (
|
"test_script": (
|
||||||
"This is a new test script. Aliases: ['script name', 'script alias']",
|
"This is a new test script. Aliases: ['script name', 'script alias']",
|
||||||
vol.Schema(schema),
|
vol.Schema(schema),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user