Include script description in LLM exposed entities (#118749)

* Include script description in LLM exposed entities

* Fix race in test

* Fix type

* Expose script

* Remove fields
This commit is contained in:
Paulus Schoutsen 2024-06-04 11:18:07 -04:00 committed by GitHub
parent 8610436948
commit 52ad90a68d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 50 additions and 0 deletions

View File

@ -29,6 +29,7 @@ from . import (
entity_registry as er, entity_registry as er,
floor_registry as fr, floor_registry as fr,
intent, intent,
service,
) )
from .singleton import singleton from .singleton import singleton
@ -407,6 +408,7 @@ def _get_exposed_entities(
entity_entry = entity_registry.async_get(state.entity_id) entity_entry = entity_registry.async_get(state.entity_id)
names = [state.name] names = [state.name]
area_names = [] area_names = []
description: str | None = None
if entity_entry is not None: if entity_entry is not None:
names.extend(entity_entry.aliases) names.extend(entity_entry.aliases)
@ -426,11 +428,25 @@ def _get_exposed_entities(
area_names.append(area.name) area_names.append(area.name)
area_names.extend(area.aliases) area_names.extend(area.aliases)
if (
state.domain == "script"
and entity_entry.unique_id
and (
service_desc := service.async_get_cached_service_description(
hass, "script", entity_entry.unique_id
)
)
):
description = service_desc.get("description")
info: dict[str, Any] = { info: dict[str, Any] = {
"names": ", ".join(names), "names": ", ".join(names),
"state": state.state, "state": state.state,
} }
if description:
info["description"] = description
if area_names: if area_names:
info["areas"] = ", ".join(area_names) info["areas"] = ", ".join(area_names)

View File

@ -655,6 +655,14 @@ def _load_services_files(
return [_load_services_file(hass, integration) for integration in integrations] return [_load_services_file(hass, integration) for integration in integrations]
@callback
def async_get_cached_service_description(
hass: HomeAssistant, domain: str, service: str
) -> dict[str, Any] | None:
"""Return the cached description for a service."""
return hass.data.get(SERVICE_DESCRIPTION_CACHE, {}).get((domain, service))
@bind_hass @bind_hass
async def async_get_all_descriptions( async def async_get_all_descriptions(
hass: HomeAssistant, hass: HomeAssistant,

View File

@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest import pytest
import voluptuous as vol import voluptuous as vol
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.components.intent import async_register_timer_handler from homeassistant.components.intent import async_register_timer_handler
from homeassistant.core import Context, HomeAssistant, State from homeassistant.core import Context, HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -293,6 +294,26 @@ async def test_assist_api_prompt(
) )
# Expose entities # Expose entities
# Create a script with a unique ID
assert await async_setup_component(
hass,
"script",
{
"script": {
"test_script": {
"description": "This is a test script",
"sequence": [],
"fields": {
"beer": {"description": "Number of beers"},
"wine": {},
},
}
}
},
)
async_expose_entity(hass, "conversation", "script.test_script", True)
entry = MockConfigEntry(title=None) entry = MockConfigEntry(title=None)
entry.add_to_hass(hass) entry.add_to_hass(hass)
device = device_registry.async_get_or_create( device = device_registry.async_get_or_create(
@ -471,6 +492,11 @@ async def test_assist_api_prompt(
"names": "Unnamed Device", "names": "Unnamed Device",
"state": "unavailable", "state": "unavailable",
}, },
"script.test_script": {
"description": "This is a test script",
"names": "test_script",
"state": "off",
},
} }
exposed_entities_prompt = ( exposed_entities_prompt = (
"An overview of the areas and the devices in this smart home:\n" "An overview of the areas and the devices in this smart home:\n"