mirror of
https://github.com/home-assistant/core.git
synced 2025-11-10 03:19:34 +00:00
Add exposed entities to the Assist LLM API prompt (#118203)
* Add exposed entities to the Assist LLM API prompt * Check expose entities in Google test * Copy Google default prompt test cases to LLM tests
This commit is contained in:
@@ -3,7 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import asdict, dataclass, replace
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
@@ -13,12 +14,20 @@ from homeassistant.components.conversation.trace import (
|
||||
ConversationTraceEventType,
|
||||
async_conversation_trace_append,
|
||||
)
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util import yaml
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
from . import area_registry, device_registry, floor_registry, intent
|
||||
from . import (
|
||||
area_registry as ar,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
floor_registry as fr,
|
||||
intent,
|
||||
)
|
||||
from .singleton import singleton
|
||||
|
||||
LLM_API_ASSIST = "assist"
|
||||
@@ -140,19 +149,16 @@ class API(ABC):
|
||||
else:
|
||||
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
|
||||
|
||||
_tool_input = ToolInput(
|
||||
tool_name=tool.name,
|
||||
tool_args=tool.parameters(tool_input.tool_args),
|
||||
platform=tool_input.platform,
|
||||
context=tool_input.context or Context(),
|
||||
user_prompt=tool_input.user_prompt,
|
||||
language=tool_input.language,
|
||||
assistant=tool_input.assistant,
|
||||
device_id=tool_input.device_id,
|
||||
return await tool.async_call(
|
||||
self.hass,
|
||||
replace(
|
||||
tool_input,
|
||||
tool_name=tool.name,
|
||||
tool_args=tool.parameters(tool_input.tool_args),
|
||||
context=tool_input.context or Context(),
|
||||
),
|
||||
)
|
||||
|
||||
return await tool.async_call(self.hass, _tool_input)
|
||||
|
||||
|
||||
class IntentTool(Tool):
|
||||
"""LLM Tool representing an Intent."""
|
||||
@@ -209,28 +215,51 @@ class AssistAPI(API):
|
||||
|
||||
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
|
||||
"""Return the prompt for the API."""
|
||||
prompt = (
|
||||
"Call the intent tools to control Home Assistant. "
|
||||
"Just pass the name to the intent."
|
||||
)
|
||||
if tool_input.assistant:
|
||||
exposed_entities: dict | None = _get_exposed_entities(
|
||||
self.hass, tool_input.assistant
|
||||
)
|
||||
else:
|
||||
exposed_entities = None
|
||||
|
||||
if not exposed_entities:
|
||||
return (
|
||||
"Only if the user wants to control a device, tell them to expose entities "
|
||||
"to their voice assistant in Home Assistant."
|
||||
)
|
||||
|
||||
prompt = [
|
||||
(
|
||||
"Call the intent tools to control Home Assistant. "
|
||||
"Just pass the name to the intent. "
|
||||
"When controlling an area, prefer passing area name."
|
||||
)
|
||||
]
|
||||
if tool_input.device_id:
|
||||
device_reg = device_registry.async_get(self.hass)
|
||||
device_reg = dr.async_get(self.hass)
|
||||
device = device_reg.async_get(tool_input.device_id)
|
||||
if device:
|
||||
area_reg = area_registry.async_get(self.hass)
|
||||
area_reg = ar.async_get(self.hass)
|
||||
if device.area_id and (area := area_reg.async_get_area(device.area_id)):
|
||||
floor_reg = floor_registry.async_get(self.hass)
|
||||
floor_reg = fr.async_get(self.hass)
|
||||
if area.floor_id and (
|
||||
floor := floor_reg.async_get_floor(area.floor_id)
|
||||
):
|
||||
prompt += f" You are in {area.name} ({floor.name})."
|
||||
prompt.append(f"You are in {area.name} ({floor.name}).")
|
||||
else:
|
||||
prompt += f" You are in {area.name}."
|
||||
prompt.append(f"You are in {area.name}.")
|
||||
if tool_input.context and tool_input.context.user_id:
|
||||
user = await self.hass.auth.async_get_user(tool_input.context.user_id)
|
||||
if user:
|
||||
prompt += f" The user name is {user.name}."
|
||||
return prompt
|
||||
prompt.append(f"The user name is {user.name}.")
|
||||
|
||||
if exposed_entities:
|
||||
prompt.append(
|
||||
"An overview of the areas and the devices in this smart home:"
|
||||
)
|
||||
prompt.append(yaml.dump(exposed_entities))
|
||||
|
||||
return "\n".join(prompt)
|
||||
|
||||
@callback
|
||||
def async_get_tools(self) -> list[Tool]:
|
||||
@@ -240,3 +269,84 @@ class AssistAPI(API):
|
||||
for intent_handler in intent.async_get(self.hass)
|
||||
if intent_handler.intent_type not in self.IGNORE_INTENTS
|
||||
]
|
||||
|
||||
|
||||
def _get_exposed_entities(
|
||||
hass: HomeAssistant, assistant: str
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Get exposed entities."""
|
||||
area_registry = ar.async_get(hass)
|
||||
entity_registry = er.async_get(hass)
|
||||
device_registry = dr.async_get(hass)
|
||||
interesting_domains = {
|
||||
"binary_sensor",
|
||||
"climate",
|
||||
"cover",
|
||||
"fan",
|
||||
"light",
|
||||
"lock",
|
||||
"sensor",
|
||||
"switch",
|
||||
"weather",
|
||||
}
|
||||
interesting_attributes = {
|
||||
"temperature",
|
||||
"current_temperature",
|
||||
"temperature_unit",
|
||||
"brightness",
|
||||
"humidity",
|
||||
"unit_of_measurement",
|
||||
"device_class",
|
||||
"current_position",
|
||||
"percentage",
|
||||
}
|
||||
|
||||
entities = {}
|
||||
|
||||
for state in hass.states.async_all():
|
||||
if state.domain not in interesting_domains:
|
||||
continue
|
||||
|
||||
if not async_should_expose(hass, assistant, state.entity_id):
|
||||
continue
|
||||
|
||||
entity_entry = entity_registry.async_get(state.entity_id)
|
||||
names = [state.name]
|
||||
area_names = []
|
||||
|
||||
if entity_entry is not None:
|
||||
names.extend(entity_entry.aliases)
|
||||
if entity_entry.area_id and (
|
||||
area := area_registry.async_get_area(entity_entry.area_id)
|
||||
):
|
||||
# Entity is in area
|
||||
area_names.append(area.name)
|
||||
area_names.extend(area.aliases)
|
||||
elif entity_entry.device_id and (
|
||||
device := device_registry.async_get(entity_entry.device_id)
|
||||
):
|
||||
# Check device area
|
||||
if device.area_id and (
|
||||
area := area_registry.async_get_area(device.area_id)
|
||||
):
|
||||
area_names.append(area.name)
|
||||
area_names.extend(area.aliases)
|
||||
|
||||
info: dict[str, Any] = {
|
||||
"names": ", ".join(names),
|
||||
"state": state.state,
|
||||
}
|
||||
|
||||
if area_names:
|
||||
info["areas"] = ", ".join(area_names)
|
||||
|
||||
if attributes := {
|
||||
attr_name: str(attr_value) if isinstance(attr_value, Enum) else attr_value
|
||||
for attr_name, attr_value in state.attributes.items()
|
||||
if attr_name in interesting_attributes
|
||||
}:
|
||||
info["attributes"] = attributes
|
||||
|
||||
entities[state.entity_id] = info
|
||||
|
||||
return entities
|
||||
|
||||
Reference in New Issue
Block a user