Add exposed entities to the Assist LLM API prompt (#118203)

* Add exposed entities to the Assist LLM API prompt

* Check expose entities in Google test

* Copy Google default prompt test cases to LLM tests
This commit is contained in:
Paulus Schoutsen
2024-05-27 00:27:08 -04:00
committed by GitHub
parent c391d73fec
commit ecb05989ca
4 changed files with 526 additions and 88 deletions

View File

@@ -3,7 +3,8 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, replace
from enum import Enum
from typing import Any
import voluptuous as vol
@@ -13,12 +14,20 @@ from homeassistant.components.conversation.trace import (
ConversationTraceEventType,
async_conversation_trace_append,
)
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import yaml
from homeassistant.util.json import JsonObjectType
from . import area_registry, device_registry, floor_registry, intent
from . import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
intent,
)
from .singleton import singleton
LLM_API_ASSIST = "assist"
@@ -140,19 +149,16 @@ class API(ABC):
else:
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
_tool_input = ToolInput(
tool_name=tool.name,
tool_args=tool.parameters(tool_input.tool_args),
platform=tool_input.platform,
context=tool_input.context or Context(),
user_prompt=tool_input.user_prompt,
language=tool_input.language,
assistant=tool_input.assistant,
device_id=tool_input.device_id,
return await tool.async_call(
self.hass,
replace(
tool_input,
tool_name=tool.name,
tool_args=tool.parameters(tool_input.tool_args),
context=tool_input.context or Context(),
),
)
return await tool.async_call(self.hass, _tool_input)
class IntentTool(Tool):
"""LLM Tool representing an Intent."""
@@ -209,28 +215,51 @@ class AssistAPI(API):
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
"""Return the prompt for the API."""
prompt = (
"Call the intent tools to control Home Assistant. "
"Just pass the name to the intent."
)
if tool_input.assistant:
exposed_entities: dict | None = _get_exposed_entities(
self.hass, tool_input.assistant
)
else:
exposed_entities = None
if not exposed_entities:
return (
"Only if the user wants to control a device, tell them to expose entities "
"to their voice assistant in Home Assistant."
)
prompt = [
(
"Call the intent tools to control Home Assistant. "
"Just pass the name to the intent. "
"When controlling an area, prefer passing area name."
)
]
if tool_input.device_id:
device_reg = device_registry.async_get(self.hass)
device_reg = dr.async_get(self.hass)
device = device_reg.async_get(tool_input.device_id)
if device:
area_reg = area_registry.async_get(self.hass)
area_reg = ar.async_get(self.hass)
if device.area_id and (area := area_reg.async_get_area(device.area_id)):
floor_reg = floor_registry.async_get(self.hass)
floor_reg = fr.async_get(self.hass)
if area.floor_id and (
floor := floor_reg.async_get_floor(area.floor_id)
):
prompt += f" You are in {area.name} ({floor.name})."
prompt.append(f"You are in {area.name} ({floor.name}).")
else:
prompt += f" You are in {area.name}."
prompt.append(f"You are in {area.name}.")
if tool_input.context and tool_input.context.user_id:
user = await self.hass.auth.async_get_user(tool_input.context.user_id)
if user:
prompt += f" The user name is {user.name}."
return prompt
prompt.append(f"The user name is {user.name}.")
if exposed_entities:
prompt.append(
"An overview of the areas and the devices in this smart home:"
)
prompt.append(yaml.dump(exposed_entities))
return "\n".join(prompt)
@callback
def async_get_tools(self) -> list[Tool]:
@@ -240,3 +269,84 @@ class AssistAPI(API):
for intent_handler in intent.async_get(self.hass)
if intent_handler.intent_type not in self.IGNORE_INTENTS
]
def _get_exposed_entities(
hass: HomeAssistant, assistant: str
) -> dict[str, dict[str, Any]]:
"""Get exposed entities."""
area_registry = ar.async_get(hass)
entity_registry = er.async_get(hass)
device_registry = dr.async_get(hass)
interesting_domains = {
"binary_sensor",
"climate",
"cover",
"fan",
"light",
"lock",
"sensor",
"switch",
"weather",
}
interesting_attributes = {
"temperature",
"current_temperature",
"temperature_unit",
"brightness",
"humidity",
"unit_of_measurement",
"device_class",
"current_position",
"percentage",
}
entities = {}
for state in hass.states.async_all():
if state.domain not in interesting_domains:
continue
if not async_should_expose(hass, assistant, state.entity_id):
continue
entity_entry = entity_registry.async_get(state.entity_id)
names = [state.name]
area_names = []
if entity_entry is not None:
names.extend(entity_entry.aliases)
if entity_entry.area_id and (
area := area_registry.async_get_area(entity_entry.area_id)
):
# Entity is in area
area_names.append(area.name)
area_names.extend(area.aliases)
elif entity_entry.device_id and (
device := device_registry.async_get(entity_entry.device_id)
):
# Check device area
if device.area_id and (
area := area_registry.async_get_area(device.area_id)
):
area_names.append(area.name)
area_names.extend(area.aliases)
info: dict[str, Any] = {
"names": ", ".join(names),
"state": state.state,
}
if area_names:
info["areas"] = ", ".join(area_names)
if attributes := {
attr_name: str(attr_value) if isinstance(attr_value, Enum) else attr_value
for attr_name, attr_value in state.attributes.items()
if attr_name in interesting_attributes
}:
info["attributes"] = attributes
entities[state.entity_id] = info
return entities