diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index 8b2e0660687..768152c314f 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -420,7 +420,9 @@ class AssistAPI(API): ): continue - tools.append(ScriptTool(self.hass, state.entity_id)) + script_tool = ScriptTool(self.hass, state.entity_id) + if script_tool.parameters.schema: + tools.append(script_tool) return tools @@ -451,12 +453,17 @@ def _get_exposed_entities( entities = {} for state in hass.states.async_all(): - if state.domain == SCRIPT_DOMAIN: - continue - if not async_should_expose(hass, assistant, state.entity_id): continue + description: str | None = None + if state.domain == SCRIPT_DOMAIN: + description, parameters = _get_cached_script_parameters( + hass, state.entity_id + ) + if parameters.schema: # Only list scripts without input fields here + continue + entity_entry = entity_registry.async_get(state.entity_id) names = [state.name] area_names = [] @@ -485,6 +492,9 @@ def _get_exposed_entities( "state": state.state, } + if description: + info["description"] = description + if area_names: info["areas"] = ", ".join(area_names) @@ -610,6 +620,83 @@ def _selector_serializer(schema: Any) -> Any: # noqa: C901 return {"type": "string"} +def _get_cached_script_parameters( + hass: HomeAssistant, entity_id: str +) -> tuple[str | None, vol.Schema]: + """Get script description and schema.""" + entity_registry = er.async_get(hass) + + description = None + parameters = vol.Schema({}) + entity_entry = entity_registry.async_get(entity_id) + if entity_entry and entity_entry.unique_id: + parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE) + + if parameters_cache is None: + parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {} + + @callback + def clear_cache(event: Event) -> None: + """Clear script parameter cache on script reload or delete.""" + if ( + event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN + and event.data[ATTR_SERVICE] in parameters_cache + ): + parameters_cache.pop(event.data[ATTR_SERVICE]) + + cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache) + + @callback + def on_homeassistant_close(event: Event) -> None: + """Cleanup.""" + cancel() + + hass.bus.async_listen_once( + EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close + ) + + if entity_entry.unique_id in parameters_cache: + return parameters_cache[entity_entry.unique_id] + + if service_desc := service.async_get_cached_service_description( + hass, SCRIPT_DOMAIN, entity_entry.unique_id + ): + description = service_desc.get("description") + schema: dict[vol.Marker, Any] = {} + fields = service_desc.get("fields", {}) + + for field, config in fields.items(): + field_description = config.get("description") + if not field_description: + field_description = config.get("name") + key: vol.Marker + if config.get("required"): + key = vol.Required(field, description=field_description) + else: + key = vol.Optional(field, description=field_description) + if "selector" in config: + schema[key] = selector.selector(config["selector"]) + else: + schema[key] = cv.string + + parameters = vol.Schema(schema) + + aliases: list[str] = [] + if entity_entry.name: + aliases.append(entity_entry.name) + if entity_entry.aliases: + aliases.extend(entity_entry.aliases) + if aliases: + if description: + description = description + ". Aliases: " + str(list(aliases)) + else: + description = "Aliases: " + str(list(aliases)) + + parameters_cache[entity_entry.unique_id] = (description, parameters) + + return description, parameters + + class ScriptTool(Tool): """LLM Tool representing a Script.""" @@ -619,86 +706,14 @@ class ScriptTool(Tool): script_entity_id: str, ) -> None: """Init the class.""" - entity_registry = er.async_get(hass) - self.name = split_entity_id(script_entity_id)[1] if self.name[0].isdigit(): self.name = "_" + self.name self._entity_id = script_entity_id - self.parameters = vol.Schema({}) - entity_entry = entity_registry.async_get(script_entity_id) - if entity_entry and entity_entry.unique_id: - parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE) - if parameters_cache is None: - parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {} - - @callback - def clear_cache(event: Event) -> None: - """Clear script parameter cache on script reload or delete.""" - if ( - event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN - and event.data[ATTR_SERVICE] in parameters_cache - ): - parameters_cache.pop(event.data[ATTR_SERVICE]) - - cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache) - - @callback - def on_homeassistant_close(event: Event) -> None: - """Cleanup.""" - cancel() - - hass.bus.async_listen_once( - EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close - ) - - if entity_entry.unique_id in parameters_cache: - self.description, self.parameters = parameters_cache[ - entity_entry.unique_id - ] - return - - if service_desc := service.async_get_cached_service_description( - hass, SCRIPT_DOMAIN, entity_entry.unique_id - ): - self.description = service_desc.get("description") - schema: dict[vol.Marker, Any] = {} - fields = service_desc.get("fields", {}) - - for field, config in fields.items(): - description = config.get("description") - if not description: - description = config.get("name") - key: vol.Marker - if config.get("required"): - key = vol.Required(field, description=description) - else: - key = vol.Optional(field, description=description) - if "selector" in config: - schema[key] = selector.selector(config["selector"]) - else: - schema[key] = cv.string - - self.parameters = vol.Schema(schema) - - aliases: list[str] = [] - if entity_entry.name: - aliases.append(entity_entry.name) - if entity_entry.aliases: - aliases.extend(entity_entry.aliases) - if aliases: - if self.description: - self.description = ( - self.description + ". Aliases: " + str(list(aliases)) - ) - else: - self.description = "Aliases: " + str(list(aliases)) - - parameters_cache[entity_entry.unique_id] = ( - self.description, - self.parameters, - ) + self.description, self.parameters = _get_cached_script_parameters( + hass, script_entity_id + ) async def async_call( self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 4d14abb9819..cd36fe18933 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -374,11 +374,16 @@ async def test_assist_api_prompt( "beer": {"description": "Number of beers"}, "wine": {}, }, - } + }, + "script_with_no_fields": { + "description": "This is another test script", + "sequence": [], + }, } }, ) async_expose_entity(hass, "conversation", "script.test_script", True) + async_expose_entity(hass, "conversation", "script.script_with_no_fields", True) entry = MockConfigEntry(title=None) entry.add_to_hass(hass) @@ -511,6 +516,10 @@ async def test_assist_api_prompt( ) ) exposed_entities_prompt = """An overview of the areas and the devices in this smart home: +- names: script_with_no_fields + domain: script + state: 'off' + description: This is another test script - names: Kitchen domain: light state: 'on' @@ -657,6 +666,10 @@ async def test_script_tool( "extra_field": {"selector": {"area": {}}}, }, }, + "script_with_no_fields": { + "description": "This is another test script", + "sequence": [], + }, "unexposed_script": { "sequence": [], }, @@ -664,6 +677,7 @@ async def test_script_tool( }, ) async_expose_entity(hass, "conversation", "script.test_script", True) + async_expose_entity(hass, "conversation", "script.script_with_no_fields", True) entity_registry.async_update_entity( "script.test_script", name="script name", aliases={"script alias"} @@ -700,7 +714,8 @@ async def test_script_tool( "test_script": ( "This is a test script. Aliases: ['script name', 'script alias']", vol.Schema(schema), - ) + ), + "script_with_no_fields": ("This is another test script", vol.Schema({})), } tool_input = llm.ToolInput( @@ -781,7 +796,8 @@ async def test_script_tool( "test_script": ( "This is a new test script. Aliases: ['script name', 'script alias']", vol.Schema(schema), - ) + ), + "script_with_no_fields": ("This is another test script", vol.Schema({})), }