Add exposed entities to the Assist LLM API prompt (#118203)

* Add exposed entities to the Assist LLM API prompt

* Check expose entities in Google test

* Copy Google default prompt test cases to LLM tests
This commit is contained in:
Paulus Schoutsen 2024-05-27 00:27:08 -04:00 committed by GitHub
parent c391d73fec
commit ecb05989ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 526 additions and 88 deletions

View File

@ -3,7 +3,8 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, replace
from enum import Enum
from typing import Any
import voluptuous as vol
@ -13,12 +14,20 @@ from homeassistant.components.conversation.trace import (
ConversationTraceEventType,
async_conversation_trace_append,
)
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import yaml
from homeassistant.util.json import JsonObjectType
from . import area_registry, device_registry, floor_registry, intent
from . import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
intent,
)
from .singleton import singleton
LLM_API_ASSIST = "assist"
@ -140,19 +149,16 @@ class API(ABC):
else:
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
_tool_input = ToolInput(
return await tool.async_call(
self.hass,
replace(
tool_input,
tool_name=tool.name,
tool_args=tool.parameters(tool_input.tool_args),
platform=tool_input.platform,
context=tool_input.context or Context(),
user_prompt=tool_input.user_prompt,
language=tool_input.language,
assistant=tool_input.assistant,
device_id=tool_input.device_id,
),
)
return await tool.async_call(self.hass, _tool_input)
class IntentTool(Tool):
"""LLM Tool representing an Intent."""
@ -209,28 +215,51 @@ class AssistAPI(API):
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
"""Return the prompt for the API."""
prompt = (
if tool_input.assistant:
exposed_entities: dict | None = _get_exposed_entities(
self.hass, tool_input.assistant
)
else:
exposed_entities = None
if not exposed_entities:
return (
"Only if the user wants to control a device, tell them to expose entities "
"to their voice assistant in Home Assistant."
)
prompt = [
(
"Call the intent tools to control Home Assistant. "
"Just pass the name to the intent. "
"When controlling an area, prefer passing area name."
)
]
if tool_input.device_id:
device_reg = device_registry.async_get(self.hass)
device_reg = dr.async_get(self.hass)
device = device_reg.async_get(tool_input.device_id)
if device:
area_reg = area_registry.async_get(self.hass)
area_reg = ar.async_get(self.hass)
if device.area_id and (area := area_reg.async_get_area(device.area_id)):
floor_reg = floor_registry.async_get(self.hass)
floor_reg = fr.async_get(self.hass)
if area.floor_id and (
floor := floor_reg.async_get_floor(area.floor_id)
):
prompt += f" You are in {area.name} ({floor.name})."
prompt.append(f"You are in {area.name} ({floor.name}).")
else:
prompt += f" You are in {area.name}."
prompt.append(f"You are in {area.name}.")
if tool_input.context and tool_input.context.user_id:
user = await self.hass.auth.async_get_user(tool_input.context.user_id)
if user:
prompt += f" The user name is {user.name}."
return prompt
prompt.append(f"The user name is {user.name}.")
if exposed_entities:
prompt.append(
"An overview of the areas and the devices in this smart home:"
)
prompt.append(yaml.dump(exposed_entities))
return "\n".join(prompt)
@callback
def async_get_tools(self) -> list[Tool]:
@ -240,3 +269,84 @@ class AssistAPI(API):
for intent_handler in intent.async_get(self.hass)
if intent_handler.intent_type not in self.IGNORE_INTENTS
]
def _get_exposed_entities(
hass: HomeAssistant, assistant: str
) -> dict[str, dict[str, Any]]:
"""Get exposed entities."""
area_registry = ar.async_get(hass)
entity_registry = er.async_get(hass)
device_registry = dr.async_get(hass)
interesting_domains = {
"binary_sensor",
"climate",
"cover",
"fan",
"light",
"lock",
"sensor",
"switch",
"weather",
}
interesting_attributes = {
"temperature",
"current_temperature",
"temperature_unit",
"brightness",
"humidity",
"unit_of_measurement",
"device_class",
"current_position",
"percentage",
}
entities = {}
for state in hass.states.async_all():
if state.domain not in interesting_domains:
continue
if not async_should_expose(hass, assistant, state.entity_id):
continue
entity_entry = entity_registry.async_get(state.entity_id)
names = [state.name]
area_names = []
if entity_entry is not None:
names.extend(entity_entry.aliases)
if entity_entry.area_id and (
area := area_registry.async_get_area(entity_entry.area_id)
):
# Entity is in area
area_names.append(area.name)
area_names.extend(area.aliases)
elif entity_entry.device_id and (
device := device_registry.async_get(entity_entry.device_id)
):
# Check device area
if device.area_id and (
area := area_registry.async_get_area(device.area_id)
):
area_names.append(area.name)
area_names.extend(area.aliases)
info: dict[str, Any] = {
"names": ", ".join(names),
"state": state.state,
}
if area_names:
info["areas"] = ", ".join(area_names)
if attributes := {
attr_name: str(attr_value) if isinstance(attr_value, Enum) else attr_value
for attr_name, attr_value in state.attributes.items()
if attr_name in interesting_attributes
}:
info["attributes"] = attributes
entities[state.entity_id] = info
return entities

View File

@ -262,7 +262,49 @@
Answer in plain text. Keep it simple and to the point.
The current time is 05:00:00.
Today's date is 05/24/24.
Call the intent tools to control Home Assistant. Just pass the name to the intent.
Call the intent tools to control Home Assistant. Just pass the name to the intent. When controlling an area, prefer passing area name.
An overview of the areas and the devices in this smart home:
light.test_device:
names: Test Device
state: unavailable
areas: Test Area
light.test_service:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_2:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_3:
names: Test Service
state: unavailable
areas: Test Area
light.test_device_2:
names: Test Device 2
state: unavailable
areas: Test Area 2
light.test_device_3:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.test_device_4:
names: Test Device 4
state: unavailable
areas: Test Area 2
light.test_device_3_2:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.none:
names: None
state: unavailable
areas: Test Area 2
light.1:
names: '1'
state: unavailable
areas: Test Area 2
''',
'role': 'user',
}),
@ -318,7 +360,49 @@
Answer in plain text. Keep it simple and to the point.
The current time is 05:00:00.
Today's date is 05/24/24.
Call the intent tools to control Home Assistant. Just pass the name to the intent.
Call the intent tools to control Home Assistant. Just pass the name to the intent. When controlling an area, prefer passing area name.
An overview of the areas and the devices in this smart home:
light.test_device:
names: Test Device
state: unavailable
areas: Test Area
light.test_service:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_2:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_3:
names: Test Service
state: unavailable
areas: Test Area
light.test_device_2:
names: Test Device 2
state: unavailable
areas: Test Area 2
light.test_device_3:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.test_device_4:
names: Test Device 4
state: unavailable
areas: Test Area 2
light.test_device_3_2:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.none:
names: None
state: unavailable
areas: Test Area 2
light.1:
names: '1'
state: unavailable
areas: Test Area 2
''',
'role': 'user',
}),

View File

@ -17,11 +17,13 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
intent,
llm,
)
from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator
@pytest.fixture(autouse=True)
@ -47,9 +49,11 @@ async def test_default_prompt(
mock_init_component,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
snapshot: SnapshotAssertion,
agent_id: str | None,
config_entry_options: {},
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test that the default prompt works."""
entry = MockConfigEntry(title=None)
@ -64,7 +68,22 @@ async def test_default_prompt(
mock_config_entry,
options={**mock_config_entry.options, **config_entry_options},
)
entities = []
def create_entity(device: dr.DeviceEntry) -> None:
"""Create an entity for a device and track entity_id."""
entity = entity_registry.async_get_or_create(
"light",
"test",
device.id,
device_id=device.id,
original_name=str(device.name),
suggested_object_id=str(device.name),
)
entity.write_unavailable_state(hass)
entities.append(entity.entity_id)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
@ -73,7 +92,9 @@ async def test_default_prompt(
model="Test Model",
suggested_area="Test Area",
)
)
for i in range(3):
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", f"{i}abcd")},
@ -83,6 +104,8 @@ async def test_default_prompt(
suggested_area="Test Area",
entry_type=dr.DeviceEntryType.SERVICE,
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "5678")},
@ -91,6 +114,8 @@ async def test_default_prompt(
model="Device 2",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876")},
@ -99,12 +124,15 @@ async def test_default_prompt(
model="Test Model 3A",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "qwer")},
name="Test Device 4",
suggested_area="Test Area 2",
)
)
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-disabled")},
@ -116,6 +144,8 @@ async def test_default_prompt(
device_registry.async_update_device(
device.id, disabled_by=dr.DeviceEntryDisabler.USER
)
create_entity(device)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-no-name")},
@ -123,6 +153,8 @@ async def test_default_prompt(
model="Test Model NoName",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-integer-values")},
@ -131,6 +163,21 @@ async def test_default_prompt(
model=3,
suggested_area="Test Area 2",
)
)
# Set options for registered entities
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "homeassistant/expose_entity",
"assistants": ["conversation"],
"entity_ids": entities,
"should_expose": True,
}
)
response = await ws_client.receive_json()
assert response["success"]
with (
patch("google.generativeai.GenerativeModel") as mock_model,
patch(

View File

@ -11,10 +11,13 @@ from homeassistant.helpers import (
area_registry as ar,
config_validation as cv,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
intent,
llm,
)
from homeassistant.setup import async_setup_component
from homeassistant.util import yaml
from tests.common import MockConfigEntry
@ -158,10 +161,12 @@ async def test_assist_api_description(hass: HomeAssistant) -> None:
async def test_assist_api_prompt(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
area_registry: ar.AreaRegistry,
floor_registry: fr.FloorRegistry,
) -> None:
"""Test prompt for the assist API."""
assert await async_setup_component(hass, "homeassistant", {})
context = Context()
tool_input = llm.ToolInput(
tool_name=None,
@ -170,41 +175,232 @@ async def test_assist_api_prompt(
context=context,
user_prompt="test_text",
language="*",
assistant="test_assistant",
assistant="conversation",
device_id="test_device",
)
api = llm.async_get_api(hass, "assist")
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
"Call the intent tools to control Home Assistant."
" Just pass the name to the intent."
"Only if the user wants to control a device, tell them to expose entities to their "
"voice assistant in Home Assistant."
)
# Expose entities
entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
tool_input.device_id = device_registry.async_get_or_create(
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
suggested_area="Test Area",
)
area = area_registry.async_get_area_by_name("Test Area")
area_registry.async_update(area.id, aliases=["Alternative name"])
entry1 = entity_registry.async_get_or_create(
"light",
"kitchen",
"mock-id-kitchen",
original_name="Kitchen",
suggested_object_id="kitchen",
)
entry2 = entity_registry.async_get_or_create(
"light",
"living_room",
"mock-id-living-room",
original_name="Living Room",
suggested_object_id="living_room",
device_id=device.id,
)
hass.states.async_set(entry1.entity_id, "on", {"friendly_name": "Kitchen"})
hass.states.async_set(entry2.entity_id, "on", {"friendly_name": "Living Room"})
def create_entity(device: dr.DeviceEntry, write_state=True) -> None:
"""Create an entity for a device and track entity_id."""
entity = entity_registry.async_get_or_create(
"light",
"test",
device.id,
device_id=device.id,
original_name=str(device.name or "Unnamed Device"),
suggested_object_id=str(device.name or "unnamed_device"),
)
if write_state:
entity.write_unavailable_state(hass)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
name="Test Device",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
).id
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
"Call the intent tools to control Home Assistant."
" Just pass the name to the intent. You are in Test Area."
)
)
for i in range(3):
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", f"{i}abcd")},
name="Test Service",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
entry_type=dr.DeviceEntryType.SERVICE,
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "5678")},
name="Test Device 2",
manufacturer="Test Manufacturer 2",
model="Device 2",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "qwer")},
name="Test Device 4",
suggested_area="Test Area 2",
)
)
device2 = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-disabled")},
name="Test Device 3 - disabled",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_update_device(
device2.id, disabled_by=dr.DeviceEntryDisabler.USER
)
create_entity(device2, False)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-no-name")},
manufacturer="Test Manufacturer NoName",
model="Test Model NoName",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-integer-values")},
name=1,
manufacturer=2,
model=3,
suggested_area="Test Area 2",
)
)
exposed_entities = llm._get_exposed_entities(hass, tool_input.assistant)
assert exposed_entities == {
"light.1": {
"areas": "Test Area 2",
"names": "1",
"state": "unavailable",
},
entry1.entity_id: {
"names": "Kitchen",
"state": "on",
},
entry2.entity_id: {
"areas": "Test Area, Alternative name",
"names": "Living Room",
"state": "on",
},
"light.test_device": {
"areas": "Test Area, Alternative name",
"names": "Test Device",
"state": "unavailable",
},
"light.test_device_2": {
"areas": "Test Area 2",
"names": "Test Device 2",
"state": "unavailable",
},
"light.test_device_3": {
"areas": "Test Area 2",
"names": "Test Device 3",
"state": "unavailable",
},
"light.test_device_4": {
"areas": "Test Area 2",
"names": "Test Device 4",
"state": "unavailable",
},
"light.test_service": {
"areas": "Test Area, Alternative name",
"names": "Test Service",
"state": "unavailable",
},
"light.test_service_2": {
"areas": "Test Area, Alternative name",
"names": "Test Service",
"state": "unavailable",
},
"light.test_service_3": {
"areas": "Test Area, Alternative name",
"names": "Test Service",
"state": "unavailable",
},
"light.unnamed_device": {
"areas": "Test Area 2",
"names": "Unnamed Device",
"state": "unavailable",
},
}
exposed_entities_prompt = (
"An overview of the areas and the devices in this smart home:\n"
+ yaml.dump(exposed_entities)
)
first_part_prompt = (
"Call the intent tools to control Home Assistant. "
"Just pass the name to the intent. "
"When controlling an area, prefer passing area name."
)
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
f"""{first_part_prompt}
{exposed_entities_prompt}"""
)
# Fake that request is made from a specific device ID
tool_input.device_id = device.id
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
f"""{first_part_prompt}
You are in Test Area.
{exposed_entities_prompt}"""
)
# Add floor
floor = floor_registry.async_create("second floor")
area = area_registry.async_get_area_by_name("Test Area")
area_registry.async_update(area.id, floor_id=floor.floor_id)
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
"Call the intent tools to control Home Assistant."
" Just pass the name to the intent. You are in Test Area (second floor)."
f"""{first_part_prompt}
You are in Test Area (second floor).
{exposed_entities_prompt}"""
)
# Add user
context.user_id = "12345"
mock_user = Mock()
mock_user.id = "12345"
@ -212,7 +408,8 @@ async def test_assist_api_prompt(
with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user):
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
"Call the intent tools to control Home Assistant."
" Just pass the name to the intent. You are in Test Area (second floor)."
" The user name is Test User."
f"""{first_part_prompt}
You are in Test Area (second floor).
The user name is Test User.
{exposed_entities_prompt}"""
)