Expose scripts with no fields as entities (#123061)

This commit is contained in:
Denis Shulyaka 2024-10-23 09:14:07 +03:00 committed by GitHub
parent 3ddef56167
commit e0e61b5262
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 113 additions and 82 deletions

View File

@ -420,7 +420,9 @@ class AssistAPI(API):
): ):
continue continue
tools.append(ScriptTool(self.hass, state.entity_id)) script_tool = ScriptTool(self.hass, state.entity_id)
if script_tool.parameters.schema:
tools.append(script_tool)
return tools return tools
@ -451,10 +453,15 @@ def _get_exposed_entities(
entities = {} entities = {}
for state in hass.states.async_all(): for state in hass.states.async_all():
if state.domain == SCRIPT_DOMAIN: if not async_should_expose(hass, assistant, state.entity_id):
continue continue
if not async_should_expose(hass, assistant, state.entity_id): description: str | None = None
if state.domain == SCRIPT_DOMAIN:
description, parameters = _get_cached_script_parameters(
hass, state.entity_id
)
if parameters.schema: # Only list scripts without input fields here
continue continue
entity_entry = entity_registry.async_get(state.entity_id) entity_entry = entity_registry.async_get(state.entity_id)
@ -485,6 +492,9 @@ def _get_exposed_entities(
"state": state.state, "state": state.state,
} }
if description:
info["description"] = description
if area_names: if area_names:
info["areas"] = ", ".join(area_names) info["areas"] = ", ".join(area_names)
@ -610,23 +620,15 @@ def _selector_serializer(schema: Any) -> Any: # noqa: C901
return {"type": "string"} return {"type": "string"}
class ScriptTool(Tool): def _get_cached_script_parameters(
"""LLM Tool representing a Script.""" hass: HomeAssistant, entity_id: str
) -> tuple[str | None, vol.Schema]:
def __init__( """Get script description and schema."""
self,
hass: HomeAssistant,
script_entity_id: str,
) -> None:
"""Init the class."""
entity_registry = er.async_get(hass) entity_registry = er.async_get(hass)
self.name = split_entity_id(script_entity_id)[1] description = None
if self.name[0].isdigit(): parameters = vol.Schema({})
self.name = "_" + self.name entity_entry = entity_registry.async_get(entity_id)
self._entity_id = script_entity_id
self.parameters = vol.Schema({})
entity_entry = entity_registry.async_get(script_entity_id)
if entity_entry and entity_entry.unique_id: if entity_entry and entity_entry.unique_id:
parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE) parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)
@ -654,33 +656,30 @@ class ScriptTool(Tool):
) )
if entity_entry.unique_id in parameters_cache: if entity_entry.unique_id in parameters_cache:
self.description, self.parameters = parameters_cache[ return parameters_cache[entity_entry.unique_id]
entity_entry.unique_id
]
return
if service_desc := service.async_get_cached_service_description( if service_desc := service.async_get_cached_service_description(
hass, SCRIPT_DOMAIN, entity_entry.unique_id hass, SCRIPT_DOMAIN, entity_entry.unique_id
): ):
self.description = service_desc.get("description") description = service_desc.get("description")
schema: dict[vol.Marker, Any] = {} schema: dict[vol.Marker, Any] = {}
fields = service_desc.get("fields", {}) fields = service_desc.get("fields", {})
for field, config in fields.items(): for field, config in fields.items():
description = config.get("description") field_description = config.get("description")
if not description: if not field_description:
description = config.get("name") field_description = config.get("name")
key: vol.Marker key: vol.Marker
if config.get("required"): if config.get("required"):
key = vol.Required(field, description=description) key = vol.Required(field, description=field_description)
else: else:
key = vol.Optional(field, description=description) key = vol.Optional(field, description=field_description)
if "selector" in config: if "selector" in config:
schema[key] = selector.selector(config["selector"]) schema[key] = selector.selector(config["selector"])
else: else:
schema[key] = cv.string schema[key] = cv.string
self.parameters = vol.Schema(schema) parameters = vol.Schema(schema)
aliases: list[str] = [] aliases: list[str] = []
if entity_entry.name: if entity_entry.name:
@ -688,16 +687,32 @@ class ScriptTool(Tool):
if entity_entry.aliases: if entity_entry.aliases:
aliases.extend(entity_entry.aliases) aliases.extend(entity_entry.aliases)
if aliases: if aliases:
if self.description: if description:
self.description = ( description = description + ". Aliases: " + str(list(aliases))
self.description + ". Aliases: " + str(list(aliases))
)
else: else:
self.description = "Aliases: " + str(list(aliases)) description = "Aliases: " + str(list(aliases))
parameters_cache[entity_entry.unique_id] = ( parameters_cache[entity_entry.unique_id] = (description, parameters)
self.description,
self.parameters, return description, parameters
class ScriptTool(Tool):
"""LLM Tool representing a Script."""
def __init__(
self,
hass: HomeAssistant,
script_entity_id: str,
) -> None:
"""Init the class."""
self.name = split_entity_id(script_entity_id)[1]
if self.name[0].isdigit():
self.name = "_" + self.name
self._entity_id = script_entity_id
self.description, self.parameters = _get_cached_script_parameters(
hass, script_entity_id
) )
async def async_call( async def async_call(

View File

@ -374,11 +374,16 @@ async def test_assist_api_prompt(
"beer": {"description": "Number of beers"}, "beer": {"description": "Number of beers"},
"wine": {}, "wine": {},
}, },
} },
"script_with_no_fields": {
"description": "This is another test script",
"sequence": [],
},
} }
}, },
) )
async_expose_entity(hass, "conversation", "script.test_script", True) async_expose_entity(hass, "conversation", "script.test_script", True)
async_expose_entity(hass, "conversation", "script.script_with_no_fields", True)
entry = MockConfigEntry(title=None) entry = MockConfigEntry(title=None)
entry.add_to_hass(hass) entry.add_to_hass(hass)
@ -511,6 +516,10 @@ async def test_assist_api_prompt(
) )
) )
exposed_entities_prompt = """An overview of the areas and the devices in this smart home: exposed_entities_prompt = """An overview of the areas and the devices in this smart home:
- names: script_with_no_fields
domain: script
state: 'off'
description: This is another test script
- names: Kitchen - names: Kitchen
domain: light domain: light
state: 'on' state: 'on'
@ -657,6 +666,10 @@ async def test_script_tool(
"extra_field": {"selector": {"area": {}}}, "extra_field": {"selector": {"area": {}}},
}, },
}, },
"script_with_no_fields": {
"description": "This is another test script",
"sequence": [],
},
"unexposed_script": { "unexposed_script": {
"sequence": [], "sequence": [],
}, },
@ -664,6 +677,7 @@ async def test_script_tool(
}, },
) )
async_expose_entity(hass, "conversation", "script.test_script", True) async_expose_entity(hass, "conversation", "script.test_script", True)
async_expose_entity(hass, "conversation", "script.script_with_no_fields", True)
entity_registry.async_update_entity( entity_registry.async_update_entity(
"script.test_script", name="script name", aliases={"script alias"} "script.test_script", name="script name", aliases={"script alias"}
@ -700,7 +714,8 @@ async def test_script_tool(
"test_script": ( "test_script": (
"This is a test script. Aliases: ['script name', 'script alias']", "This is a test script. Aliases: ['script name', 'script alias']",
vol.Schema(schema), vol.Schema(schema),
) ),
"script_with_no_fields": ("This is another test script", vol.Schema({})),
} }
tool_input = llm.ToolInput( tool_input = llm.ToolInput(
@ -781,7 +796,8 @@ async def test_script_tool(
"test_script": ( "test_script": (
"This is a new test script. Aliases: ['script name', 'script alias']", "This is a new test script. Aliases: ['script name', 'script alias']",
vol.Schema(schema), vol.Schema(schema),
) ),
"script_with_no_fields": ("This is another test script", vol.Schema({})),
} }