From 52ad90a68d432a30cfac08c37b886576b06bb884 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 4 Jun 2024 11:18:07 -0400 Subject: [PATCH] Include script description in LLM exposed entities (#118749) * Include script description in LLM exposed entities * Fix race in test * Fix type * Expose script * Remove fields --- homeassistant/helpers/llm.py | 16 ++++++++++++++++ homeassistant/helpers/service.py | 8 ++++++++ tests/helpers/test_llm.py | 26 ++++++++++++++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index 31e3c791630..3c240692d52 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -29,6 +29,7 @@ from . import ( entity_registry as er, floor_registry as fr, intent, + service, ) from .singleton import singleton @@ -407,6 +408,7 @@ def _get_exposed_entities( entity_entry = entity_registry.async_get(state.entity_id) names = [state.name] area_names = [] + description: str | None = None if entity_entry is not None: names.extend(entity_entry.aliases) @@ -426,11 +428,25 @@ def _get_exposed_entities( area_names.append(area.name) area_names.extend(area.aliases) + if ( + state.domain == "script" + and entity_entry.unique_id + and ( + service_desc := service.async_get_cached_service_description( + hass, "script", entity_entry.unique_id + ) + ) + ): + description = service_desc.get("description") + info: dict[str, Any] = { "names": ", ".join(names), "state": state.state, } + if description: + info["description"] = description + if area_names: info["areas"] = ", ".join(area_names) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index d20cba8909f..3a828ada9c2 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -655,6 +655,14 @@ def _load_services_files( return [_load_services_file(hass, integration) for integration in integrations] +@callback +def async_get_cached_service_description( + hass: HomeAssistant, domain: str, service: str +) -> dict[str, Any] | None: + """Return the cached description for a service.""" + return hass.data.get(SERVICE_DESCRIPTION_CACHE, {}).get((domain, service)) + + @bind_hass async def async_get_all_descriptions( hass: HomeAssistant, diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 6c9451bc843..3f61ed8a0ed 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -5,6 +5,7 @@ from unittest.mock import patch import pytest import voluptuous as vol +from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.components.intent import async_register_timer_handler from homeassistant.core import Context, HomeAssistant, State from homeassistant.exceptions import HomeAssistantError @@ -293,6 +294,26 @@ async def test_assist_api_prompt( ) # Expose entities + + # Create a script with a unique ID + assert await async_setup_component( + hass, + "script", + { + "script": { + "test_script": { + "description": "This is a test script", + "sequence": [], + "fields": { + "beer": {"description": "Number of beers"}, + "wine": {}, + }, + } + } + }, + ) + async_expose_entity(hass, "conversation", "script.test_script", True) + entry = MockConfigEntry(title=None) entry.add_to_hass(hass) device = device_registry.async_get_or_create( @@ -471,6 +492,11 @@ async def test_assist_api_prompt( "names": "Unnamed Device", "state": "unavailable", }, + "script.test_script": { + "description": "This is a test script", + "names": "test_script", + "state": "off", + }, } exposed_entities_prompt = ( "An overview of the areas and the devices in this smart home:\n"