Add exposed entities to the Assist LLM API prompt (#118203)

* Add exposed entities to the Assist LLM API prompt

* Check expose entities in Google test

* Copy Google default prompt test cases to LLM tests
This commit is contained in:
Paulus Schoutsen 2024-05-27 00:27:08 -04:00 committed by GitHub
parent c391d73fec
commit ecb05989ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 526 additions and 88 deletions

View File

@ -3,7 +3,8 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass, replace
from enum import Enum
from typing import Any from typing import Any
import voluptuous as vol import voluptuous as vol
@ -13,12 +14,20 @@ from homeassistant.components.conversation.trace import (
ConversationTraceEventType, ConversationTraceEventType,
async_conversation_trace_append, async_conversation_trace_append,
) )
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.components.weather.intent import INTENT_GET_WEATHER from homeassistant.components.weather.intent import INTENT_GET_WEATHER
from homeassistant.core import Context, HomeAssistant, callback from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import yaml
from homeassistant.util.json import JsonObjectType from homeassistant.util.json import JsonObjectType
from . import area_registry, device_registry, floor_registry, intent from . import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
intent,
)
from .singleton import singleton from .singleton import singleton
LLM_API_ASSIST = "assist" LLM_API_ASSIST = "assist"
@ -140,19 +149,16 @@ class API(ABC):
else: else:
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found') raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
_tool_input = ToolInput( return await tool.async_call(
self.hass,
replace(
tool_input,
tool_name=tool.name, tool_name=tool.name,
tool_args=tool.parameters(tool_input.tool_args), tool_args=tool.parameters(tool_input.tool_args),
platform=tool_input.platform,
context=tool_input.context or Context(), context=tool_input.context or Context(),
user_prompt=tool_input.user_prompt, ),
language=tool_input.language,
assistant=tool_input.assistant,
device_id=tool_input.device_id,
) )
return await tool.async_call(self.hass, _tool_input)
class IntentTool(Tool): class IntentTool(Tool):
"""LLM Tool representing an Intent.""" """LLM Tool representing an Intent."""
@ -209,28 +215,51 @@ class AssistAPI(API):
async def async_get_api_prompt(self, tool_input: ToolInput) -> str: async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
"""Return the prompt for the API.""" """Return the prompt for the API."""
prompt = ( if tool_input.assistant:
exposed_entities: dict | None = _get_exposed_entities(
self.hass, tool_input.assistant
)
else:
exposed_entities = None
if not exposed_entities:
return (
"Only if the user wants to control a device, tell them to expose entities "
"to their voice assistant in Home Assistant."
)
prompt = [
(
"Call the intent tools to control Home Assistant. " "Call the intent tools to control Home Assistant. "
"Just pass the name to the intent. " "Just pass the name to the intent. "
"When controlling an area, prefer passing area name."
) )
]
if tool_input.device_id: if tool_input.device_id:
device_reg = device_registry.async_get(self.hass) device_reg = dr.async_get(self.hass)
device = device_reg.async_get(tool_input.device_id) device = device_reg.async_get(tool_input.device_id)
if device: if device:
area_reg = area_registry.async_get(self.hass) area_reg = ar.async_get(self.hass)
if device.area_id and (area := area_reg.async_get_area(device.area_id)): if device.area_id and (area := area_reg.async_get_area(device.area_id)):
floor_reg = floor_registry.async_get(self.hass) floor_reg = fr.async_get(self.hass)
if area.floor_id and ( if area.floor_id and (
floor := floor_reg.async_get_floor(area.floor_id) floor := floor_reg.async_get_floor(area.floor_id)
): ):
prompt += f" You are in {area.name} ({floor.name})." prompt.append(f"You are in {area.name} ({floor.name}).")
else: else:
prompt += f" You are in {area.name}." prompt.append(f"You are in {area.name}.")
if tool_input.context and tool_input.context.user_id: if tool_input.context and tool_input.context.user_id:
user = await self.hass.auth.async_get_user(tool_input.context.user_id) user = await self.hass.auth.async_get_user(tool_input.context.user_id)
if user: if user:
prompt += f" The user name is {user.name}." prompt.append(f"The user name is {user.name}.")
return prompt
if exposed_entities:
prompt.append(
"An overview of the areas and the devices in this smart home:"
)
prompt.append(yaml.dump(exposed_entities))
return "\n".join(prompt)
@callback @callback
def async_get_tools(self) -> list[Tool]: def async_get_tools(self) -> list[Tool]:
@ -240,3 +269,84 @@ class AssistAPI(API):
for intent_handler in intent.async_get(self.hass) for intent_handler in intent.async_get(self.hass)
if intent_handler.intent_type not in self.IGNORE_INTENTS if intent_handler.intent_type not in self.IGNORE_INTENTS
] ]
def _get_exposed_entities(
hass: HomeAssistant, assistant: str
) -> dict[str, dict[str, Any]]:
"""Get exposed entities."""
area_registry = ar.async_get(hass)
entity_registry = er.async_get(hass)
device_registry = dr.async_get(hass)
interesting_domains = {
"binary_sensor",
"climate",
"cover",
"fan",
"light",
"lock",
"sensor",
"switch",
"weather",
}
interesting_attributes = {
"temperature",
"current_temperature",
"temperature_unit",
"brightness",
"humidity",
"unit_of_measurement",
"device_class",
"current_position",
"percentage",
}
entities = {}
for state in hass.states.async_all():
if state.domain not in interesting_domains:
continue
if not async_should_expose(hass, assistant, state.entity_id):
continue
entity_entry = entity_registry.async_get(state.entity_id)
names = [state.name]
area_names = []
if entity_entry is not None:
names.extend(entity_entry.aliases)
if entity_entry.area_id and (
area := area_registry.async_get_area(entity_entry.area_id)
):
# Entity is in area
area_names.append(area.name)
area_names.extend(area.aliases)
elif entity_entry.device_id and (
device := device_registry.async_get(entity_entry.device_id)
):
# Check device area
if device.area_id and (
area := area_registry.async_get_area(device.area_id)
):
area_names.append(area.name)
area_names.extend(area.aliases)
info: dict[str, Any] = {
"names": ", ".join(names),
"state": state.state,
}
if area_names:
info["areas"] = ", ".join(area_names)
if attributes := {
attr_name: str(attr_value) if isinstance(attr_value, Enum) else attr_value
for attr_name, attr_value in state.attributes.items()
if attr_name in interesting_attributes
}:
info["attributes"] = attributes
entities[state.entity_id] = info
return entities

View File

@ -262,7 +262,49 @@
Answer in plain text. Keep it simple and to the point. Answer in plain text. Keep it simple and to the point.
The current time is 05:00:00. The current time is 05:00:00.
Today's date is 05/24/24. Today's date is 05/24/24.
Call the intent tools to control Home Assistant. Just pass the name to the intent. Call the intent tools to control Home Assistant. Just pass the name to the intent. When controlling an area, prefer passing area name.
An overview of the areas and the devices in this smart home:
light.test_device:
names: Test Device
state: unavailable
areas: Test Area
light.test_service:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_2:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_3:
names: Test Service
state: unavailable
areas: Test Area
light.test_device_2:
names: Test Device 2
state: unavailable
areas: Test Area 2
light.test_device_3:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.test_device_4:
names: Test Device 4
state: unavailable
areas: Test Area 2
light.test_device_3_2:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.none:
names: None
state: unavailable
areas: Test Area 2
light.1:
names: '1'
state: unavailable
areas: Test Area 2
''', ''',
'role': 'user', 'role': 'user',
}), }),
@ -318,7 +360,49 @@
Answer in plain text. Keep it simple and to the point. Answer in plain text. Keep it simple and to the point.
The current time is 05:00:00. The current time is 05:00:00.
Today's date is 05/24/24. Today's date is 05/24/24.
Call the intent tools to control Home Assistant. Just pass the name to the intent. Call the intent tools to control Home Assistant. Just pass the name to the intent. When controlling an area, prefer passing area name.
An overview of the areas and the devices in this smart home:
light.test_device:
names: Test Device
state: unavailable
areas: Test Area
light.test_service:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_2:
names: Test Service
state: unavailable
areas: Test Area
light.test_service_3:
names: Test Service
state: unavailable
areas: Test Area
light.test_device_2:
names: Test Device 2
state: unavailable
areas: Test Area 2
light.test_device_3:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.test_device_4:
names: Test Device 4
state: unavailable
areas: Test Area 2
light.test_device_3_2:
names: Test Device 3
state: unavailable
areas: Test Area 2
light.none:
names: None
state: unavailable
areas: Test Area 2
light.1:
names: '1'
state: unavailable
areas: Test Area 2
''', ''',
'role': 'user', 'role': 'user',
}), }),

View File

@ -17,11 +17,13 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import ( from homeassistant.helpers import (
area_registry as ar, area_registry as ar,
device_registry as dr, device_registry as dr,
entity_registry as er,
intent, intent,
llm, llm,
) )
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@ -47,9 +49,11 @@ async def test_default_prompt(
mock_init_component, mock_init_component,
area_registry: ar.AreaRegistry, area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
agent_id: str | None, agent_id: str | None,
config_entry_options: {}, config_entry_options: {},
hass_ws_client: WebSocketGenerator,
) -> None: ) -> None:
"""Test that the default prompt works.""" """Test that the default prompt works."""
entry = MockConfigEntry(title=None) entry = MockConfigEntry(title=None)
@ -64,7 +68,22 @@ async def test_default_prompt(
mock_config_entry, mock_config_entry,
options={**mock_config_entry.options, **config_entry_options}, options={**mock_config_entry.options, **config_entry_options},
) )
entities = []
def create_entity(device: dr.DeviceEntry) -> None:
"""Create an entity for a device and track entity_id."""
entity = entity_registry.async_get_or_create(
"light",
"test",
device.id,
device_id=device.id,
original_name=str(device.name),
suggested_object_id=str(device.name),
)
entity.write_unavailable_state(hass)
entities.append(entity.entity_id)
create_entity(
device_registry.async_get_or_create( device_registry.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,
connections={("test", "1234")}, connections={("test", "1234")},
@ -73,7 +92,9 @@ async def test_default_prompt(
model="Test Model", model="Test Model",
suggested_area="Test Area", suggested_area="Test Area",
) )
)
for i in range(3): for i in range(3):
create_entity(
device_registry.async_get_or_create( device_registry.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,
connections={("test", f"{i}abcd")}, connections={("test", f"{i}abcd")},
@ -83,6 +104,8 @@ async def test_default_prompt(
suggested_area="Test Area", suggested_area="Test Area",
entry_type=dr.DeviceEntryType.SERVICE, entry_type=dr.DeviceEntryType.SERVICE,
) )
)
create_entity(
device_registry.async_get_or_create( device_registry.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,
connections={("test", "5678")}, connections={("test", "5678")},
@ -91,6 +114,8 @@ async def test_default_prompt(
model="Device 2", model="Device 2",
suggested_area="Test Area 2", suggested_area="Test Area 2",
) )
)
create_entity(
device_registry.async_get_or_create( device_registry.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,
connections={("test", "9876")}, connections={("test", "9876")},
@ -99,12 +124,15 @@ async def test_default_prompt(
model="Test Model 3A", model="Test Model 3A",
suggested_area="Test Area 2", suggested_area="Test Area 2",
) )
)
create_entity(
device_registry.async_get_or_create( device_registry.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,
connections={("test", "qwer")}, connections={("test", "qwer")},
name="Test Device 4", name="Test Device 4",
suggested_area="Test Area 2", suggested_area="Test Area 2",
) )
)
device = device_registry.async_get_or_create( device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,
connections={("test", "9876-disabled")}, connections={("test", "9876-disabled")},
@ -116,6 +144,8 @@ async def test_default_prompt(
device_registry.async_update_device( device_registry.async_update_device(
device.id, disabled_by=dr.DeviceEntryDisabler.USER device.id, disabled_by=dr.DeviceEntryDisabler.USER
) )
create_entity(device)
create_entity(
device_registry.async_get_or_create( device_registry.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,
connections={("test", "9876-no-name")}, connections={("test", "9876-no-name")},
@ -123,6 +153,8 @@ async def test_default_prompt(
model="Test Model NoName", model="Test Model NoName",
suggested_area="Test Area 2", suggested_area="Test Area 2",
) )
)
create_entity(
device_registry.async_get_or_create( device_registry.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,
connections={("test", "9876-integer-values")}, connections={("test", "9876-integer-values")},
@ -131,6 +163,21 @@ async def test_default_prompt(
model=3, model=3,
suggested_area="Test Area 2", suggested_area="Test Area 2",
) )
)
# Set options for registered entities
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "homeassistant/expose_entity",
"assistants": ["conversation"],
"entity_ids": entities,
"should_expose": True,
}
)
response = await ws_client.receive_json()
assert response["success"]
with ( with (
patch("google.generativeai.GenerativeModel") as mock_model, patch("google.generativeai.GenerativeModel") as mock_model,
patch( patch(

View File

@ -11,10 +11,13 @@ from homeassistant.helpers import (
area_registry as ar, area_registry as ar,
config_validation as cv, config_validation as cv,
device_registry as dr, device_registry as dr,
entity_registry as er,
floor_registry as fr, floor_registry as fr,
intent, intent,
llm, llm,
) )
from homeassistant.setup import async_setup_component
from homeassistant.util import yaml
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -158,10 +161,12 @@ async def test_assist_api_description(hass: HomeAssistant) -> None:
async def test_assist_api_prompt( async def test_assist_api_prompt(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
area_registry: ar.AreaRegistry, area_registry: ar.AreaRegistry,
floor_registry: fr.FloorRegistry, floor_registry: fr.FloorRegistry,
) -> None: ) -> None:
"""Test prompt for the assist API.""" """Test prompt for the assist API."""
assert await async_setup_component(hass, "homeassistant", {})
context = Context() context = Context()
tool_input = llm.ToolInput( tool_input = llm.ToolInput(
tool_name=None, tool_name=None,
@ -170,41 +175,232 @@ async def test_assist_api_prompt(
context=context, context=context,
user_prompt="test_text", user_prompt="test_text",
language="*", language="*",
assistant="test_assistant", assistant="conversation",
device_id="test_device", device_id="test_device",
) )
api = llm.async_get_api(hass, "assist") api = llm.async_get_api(hass, "assist")
prompt = await api.async_get_api_prompt(tool_input) prompt = await api.async_get_api_prompt(tool_input)
assert prompt == ( assert prompt == (
"Call the intent tools to control Home Assistant." "Only if the user wants to control a device, tell them to expose entities to their "
" Just pass the name to the intent." "voice assistant in Home Assistant."
) )
# Expose entities
entry = MockConfigEntry(title=None) entry = MockConfigEntry(title=None)
entry.add_to_hass(hass) entry.add_to_hass(hass)
tool_input.device_id = device_registry.async_get_or_create( device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
suggested_area="Test Area",
)
area = area_registry.async_get_area_by_name("Test Area")
area_registry.async_update(area.id, aliases=["Alternative name"])
entry1 = entity_registry.async_get_or_create(
"light",
"kitchen",
"mock-id-kitchen",
original_name="Kitchen",
suggested_object_id="kitchen",
)
entry2 = entity_registry.async_get_or_create(
"light",
"living_room",
"mock-id-living-room",
original_name="Living Room",
suggested_object_id="living_room",
device_id=device.id,
)
hass.states.async_set(entry1.entity_id, "on", {"friendly_name": "Kitchen"})
hass.states.async_set(entry2.entity_id, "on", {"friendly_name": "Living Room"})
def create_entity(device: dr.DeviceEntry, write_state=True) -> None:
"""Create an entity for a device and track entity_id."""
entity = entity_registry.async_get_or_create(
"light",
"test",
device.id,
device_id=device.id,
original_name=str(device.name or "Unnamed Device"),
suggested_object_id=str(device.name or "unnamed_device"),
)
if write_state:
entity.write_unavailable_state(hass)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id, config_entry_id=entry.entry_id,
connections={("test", "1234")}, connections={("test", "1234")},
name="Test Device", name="Test Device",
manufacturer="Test Manufacturer", manufacturer="Test Manufacturer",
model="Test Model", model="Test Model",
suggested_area="Test Area", suggested_area="Test Area",
).id )
prompt = await api.async_get_api_prompt(tool_input) )
assert prompt == ( for i in range(3):
"Call the intent tools to control Home Assistant." create_entity(
" Just pass the name to the intent. You are in Test Area." device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", f"{i}abcd")},
name="Test Service",
manufacturer="Test Manufacturer",
model="Test Model",
suggested_area="Test Area",
entry_type=dr.DeviceEntryType.SERVICE,
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "5678")},
name="Test Device 2",
manufacturer="Test Manufacturer 2",
model="Device 2",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876")},
name="Test Device 3",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "qwer")},
name="Test Device 4",
suggested_area="Test Area 2",
)
)
device2 = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-disabled")},
name="Test Device 3 - disabled",
manufacturer="Test Manufacturer 3",
model="Test Model 3A",
suggested_area="Test Area 2",
)
device_registry.async_update_device(
device2.id, disabled_by=dr.DeviceEntryDisabler.USER
)
create_entity(device2, False)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-no-name")},
manufacturer="Test Manufacturer NoName",
model="Test Model NoName",
suggested_area="Test Area 2",
)
)
create_entity(
device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "9876-integer-values")},
name=1,
manufacturer=2,
model=3,
suggested_area="Test Area 2",
)
) )
exposed_entities = llm._get_exposed_entities(hass, tool_input.assistant)
assert exposed_entities == {
"light.1": {
"areas": "Test Area 2",
"names": "1",
"state": "unavailable",
},
entry1.entity_id: {
"names": "Kitchen",
"state": "on",
},
entry2.entity_id: {
"areas": "Test Area, Alternative name",
"names": "Living Room",
"state": "on",
},
"light.test_device": {
"areas": "Test Area, Alternative name",
"names": "Test Device",
"state": "unavailable",
},
"light.test_device_2": {
"areas": "Test Area 2",
"names": "Test Device 2",
"state": "unavailable",
},
"light.test_device_3": {
"areas": "Test Area 2",
"names": "Test Device 3",
"state": "unavailable",
},
"light.test_device_4": {
"areas": "Test Area 2",
"names": "Test Device 4",
"state": "unavailable",
},
"light.test_service": {
"areas": "Test Area, Alternative name",
"names": "Test Service",
"state": "unavailable",
},
"light.test_service_2": {
"areas": "Test Area, Alternative name",
"names": "Test Service",
"state": "unavailable",
},
"light.test_service_3": {
"areas": "Test Area, Alternative name",
"names": "Test Service",
"state": "unavailable",
},
"light.unnamed_device": {
"areas": "Test Area 2",
"names": "Unnamed Device",
"state": "unavailable",
},
}
exposed_entities_prompt = (
"An overview of the areas and the devices in this smart home:\n"
+ yaml.dump(exposed_entities)
)
first_part_prompt = (
"Call the intent tools to control Home Assistant. "
"Just pass the name to the intent. "
"When controlling an area, prefer passing area name."
)
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
f"""{first_part_prompt}
{exposed_entities_prompt}"""
)
# Fake that request is made from a specific device ID
tool_input.device_id = device.id
prompt = await api.async_get_api_prompt(tool_input)
assert prompt == (
f"""{first_part_prompt}
You are in Test Area.
{exposed_entities_prompt}"""
)
# Add floor
floor = floor_registry.async_create("second floor") floor = floor_registry.async_create("second floor")
area = area_registry.async_get_area_by_name("Test Area")
area_registry.async_update(area.id, floor_id=floor.floor_id) area_registry.async_update(area.id, floor_id=floor.floor_id)
prompt = await api.async_get_api_prompt(tool_input) prompt = await api.async_get_api_prompt(tool_input)
assert prompt == ( assert prompt == (
"Call the intent tools to control Home Assistant." f"""{first_part_prompt}
" Just pass the name to the intent. You are in Test Area (second floor)." You are in Test Area (second floor).
{exposed_entities_prompt}"""
) )
# Add user
context.user_id = "12345" context.user_id = "12345"
mock_user = Mock() mock_user = Mock()
mock_user.id = "12345" mock_user.id = "12345"
@ -212,7 +408,8 @@ async def test_assist_api_prompt(
with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user): with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user):
prompt = await api.async_get_api_prompt(tool_input) prompt = await api.async_get_api_prompt(tool_input)
assert prompt == ( assert prompt == (
"Call the intent tools to control Home Assistant." f"""{first_part_prompt}
" Just pass the name to the intent. You are in Test Area (second floor)." You are in Test Area (second floor).
" The user name is Test User." The user name is Test User.
{exposed_entities_prompt}"""
) )