mirror of
https://github.com/home-assistant/core.git
synced 2025-04-27 10:47:51 +00:00
Add exposed entities to the Assist LLM API prompt (#118203)
* Add exposed entities to the Assist LLM API prompt * Check expose entities in Google test * Copy Google default prompt test cases to LLM tests
This commit is contained in:
parent
c391d73fec
commit
ecb05989ca
@ -3,7 +3,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass, replace
|
||||||
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -13,12 +14,20 @@ from homeassistant.components.conversation.trace import (
|
|||||||
ConversationTraceEventType,
|
ConversationTraceEventType,
|
||||||
async_conversation_trace_append,
|
async_conversation_trace_append,
|
||||||
)
|
)
|
||||||
|
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||||
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
|
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
|
||||||
from homeassistant.core import Context, HomeAssistant, callback
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.util import yaml
|
||||||
from homeassistant.util.json import JsonObjectType
|
from homeassistant.util.json import JsonObjectType
|
||||||
|
|
||||||
from . import area_registry, device_registry, floor_registry, intent
|
from . import (
|
||||||
|
area_registry as ar,
|
||||||
|
device_registry as dr,
|
||||||
|
entity_registry as er,
|
||||||
|
floor_registry as fr,
|
||||||
|
intent,
|
||||||
|
)
|
||||||
from .singleton import singleton
|
from .singleton import singleton
|
||||||
|
|
||||||
LLM_API_ASSIST = "assist"
|
LLM_API_ASSIST = "assist"
|
||||||
@ -140,19 +149,16 @@ class API(ABC):
|
|||||||
else:
|
else:
|
||||||
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
|
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
|
||||||
|
|
||||||
_tool_input = ToolInput(
|
return await tool.async_call(
|
||||||
|
self.hass,
|
||||||
|
replace(
|
||||||
|
tool_input,
|
||||||
tool_name=tool.name,
|
tool_name=tool.name,
|
||||||
tool_args=tool.parameters(tool_input.tool_args),
|
tool_args=tool.parameters(tool_input.tool_args),
|
||||||
platform=tool_input.platform,
|
|
||||||
context=tool_input.context or Context(),
|
context=tool_input.context or Context(),
|
||||||
user_prompt=tool_input.user_prompt,
|
),
|
||||||
language=tool_input.language,
|
|
||||||
assistant=tool_input.assistant,
|
|
||||||
device_id=tool_input.device_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return await tool.async_call(self.hass, _tool_input)
|
|
||||||
|
|
||||||
|
|
||||||
class IntentTool(Tool):
|
class IntentTool(Tool):
|
||||||
"""LLM Tool representing an Intent."""
|
"""LLM Tool representing an Intent."""
|
||||||
@ -209,28 +215,51 @@ class AssistAPI(API):
|
|||||||
|
|
||||||
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
|
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
|
||||||
"""Return the prompt for the API."""
|
"""Return the prompt for the API."""
|
||||||
prompt = (
|
if tool_input.assistant:
|
||||||
|
exposed_entities: dict | None = _get_exposed_entities(
|
||||||
|
self.hass, tool_input.assistant
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
exposed_entities = None
|
||||||
|
|
||||||
|
if not exposed_entities:
|
||||||
|
return (
|
||||||
|
"Only if the user wants to control a device, tell them to expose entities "
|
||||||
|
"to their voice assistant in Home Assistant."
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = [
|
||||||
|
(
|
||||||
"Call the intent tools to control Home Assistant. "
|
"Call the intent tools to control Home Assistant. "
|
||||||
"Just pass the name to the intent. "
|
"Just pass the name to the intent. "
|
||||||
|
"When controlling an area, prefer passing area name."
|
||||||
)
|
)
|
||||||
|
]
|
||||||
if tool_input.device_id:
|
if tool_input.device_id:
|
||||||
device_reg = device_registry.async_get(self.hass)
|
device_reg = dr.async_get(self.hass)
|
||||||
device = device_reg.async_get(tool_input.device_id)
|
device = device_reg.async_get(tool_input.device_id)
|
||||||
if device:
|
if device:
|
||||||
area_reg = area_registry.async_get(self.hass)
|
area_reg = ar.async_get(self.hass)
|
||||||
if device.area_id and (area := area_reg.async_get_area(device.area_id)):
|
if device.area_id and (area := area_reg.async_get_area(device.area_id)):
|
||||||
floor_reg = floor_registry.async_get(self.hass)
|
floor_reg = fr.async_get(self.hass)
|
||||||
if area.floor_id and (
|
if area.floor_id and (
|
||||||
floor := floor_reg.async_get_floor(area.floor_id)
|
floor := floor_reg.async_get_floor(area.floor_id)
|
||||||
):
|
):
|
||||||
prompt += f" You are in {area.name} ({floor.name})."
|
prompt.append(f"You are in {area.name} ({floor.name}).")
|
||||||
else:
|
else:
|
||||||
prompt += f" You are in {area.name}."
|
prompt.append(f"You are in {area.name}.")
|
||||||
if tool_input.context and tool_input.context.user_id:
|
if tool_input.context and tool_input.context.user_id:
|
||||||
user = await self.hass.auth.async_get_user(tool_input.context.user_id)
|
user = await self.hass.auth.async_get_user(tool_input.context.user_id)
|
||||||
if user:
|
if user:
|
||||||
prompt += f" The user name is {user.name}."
|
prompt.append(f"The user name is {user.name}.")
|
||||||
return prompt
|
|
||||||
|
if exposed_entities:
|
||||||
|
prompt.append(
|
||||||
|
"An overview of the areas and the devices in this smart home:"
|
||||||
|
)
|
||||||
|
prompt.append(yaml.dump(exposed_entities))
|
||||||
|
|
||||||
|
return "\n".join(prompt)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_tools(self) -> list[Tool]:
|
def async_get_tools(self) -> list[Tool]:
|
||||||
@ -240,3 +269,84 @@ class AssistAPI(API):
|
|||||||
for intent_handler in intent.async_get(self.hass)
|
for intent_handler in intent.async_get(self.hass)
|
||||||
if intent_handler.intent_type not in self.IGNORE_INTENTS
|
if intent_handler.intent_type not in self.IGNORE_INTENTS
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_exposed_entities(
|
||||||
|
hass: HomeAssistant, assistant: str
|
||||||
|
) -> dict[str, dict[str, Any]]:
|
||||||
|
"""Get exposed entities."""
|
||||||
|
area_registry = ar.async_get(hass)
|
||||||
|
entity_registry = er.async_get(hass)
|
||||||
|
device_registry = dr.async_get(hass)
|
||||||
|
interesting_domains = {
|
||||||
|
"binary_sensor",
|
||||||
|
"climate",
|
||||||
|
"cover",
|
||||||
|
"fan",
|
||||||
|
"light",
|
||||||
|
"lock",
|
||||||
|
"sensor",
|
||||||
|
"switch",
|
||||||
|
"weather",
|
||||||
|
}
|
||||||
|
interesting_attributes = {
|
||||||
|
"temperature",
|
||||||
|
"current_temperature",
|
||||||
|
"temperature_unit",
|
||||||
|
"brightness",
|
||||||
|
"humidity",
|
||||||
|
"unit_of_measurement",
|
||||||
|
"device_class",
|
||||||
|
"current_position",
|
||||||
|
"percentage",
|
||||||
|
}
|
||||||
|
|
||||||
|
entities = {}
|
||||||
|
|
||||||
|
for state in hass.states.async_all():
|
||||||
|
if state.domain not in interesting_domains:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not async_should_expose(hass, assistant, state.entity_id):
|
||||||
|
continue
|
||||||
|
|
||||||
|
entity_entry = entity_registry.async_get(state.entity_id)
|
||||||
|
names = [state.name]
|
||||||
|
area_names = []
|
||||||
|
|
||||||
|
if entity_entry is not None:
|
||||||
|
names.extend(entity_entry.aliases)
|
||||||
|
if entity_entry.area_id and (
|
||||||
|
area := area_registry.async_get_area(entity_entry.area_id)
|
||||||
|
):
|
||||||
|
# Entity is in area
|
||||||
|
area_names.append(area.name)
|
||||||
|
area_names.extend(area.aliases)
|
||||||
|
elif entity_entry.device_id and (
|
||||||
|
device := device_registry.async_get(entity_entry.device_id)
|
||||||
|
):
|
||||||
|
# Check device area
|
||||||
|
if device.area_id and (
|
||||||
|
area := area_registry.async_get_area(device.area_id)
|
||||||
|
):
|
||||||
|
area_names.append(area.name)
|
||||||
|
area_names.extend(area.aliases)
|
||||||
|
|
||||||
|
info: dict[str, Any] = {
|
||||||
|
"names": ", ".join(names),
|
||||||
|
"state": state.state,
|
||||||
|
}
|
||||||
|
|
||||||
|
if area_names:
|
||||||
|
info["areas"] = ", ".join(area_names)
|
||||||
|
|
||||||
|
if attributes := {
|
||||||
|
attr_name: str(attr_value) if isinstance(attr_value, Enum) else attr_value
|
||||||
|
for attr_name, attr_value in state.attributes.items()
|
||||||
|
if attr_name in interesting_attributes
|
||||||
|
}:
|
||||||
|
info["attributes"] = attributes
|
||||||
|
|
||||||
|
entities[state.entity_id] = info
|
||||||
|
|
||||||
|
return entities
|
||||||
|
@ -262,7 +262,49 @@
|
|||||||
Answer in plain text. Keep it simple and to the point.
|
Answer in plain text. Keep it simple and to the point.
|
||||||
The current time is 05:00:00.
|
The current time is 05:00:00.
|
||||||
Today's date is 05/24/24.
|
Today's date is 05/24/24.
|
||||||
Call the intent tools to control Home Assistant. Just pass the name to the intent.
|
Call the intent tools to control Home Assistant. Just pass the name to the intent. When controlling an area, prefer passing area name.
|
||||||
|
An overview of the areas and the devices in this smart home:
|
||||||
|
light.test_device:
|
||||||
|
names: Test Device
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area
|
||||||
|
light.test_service:
|
||||||
|
names: Test Service
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area
|
||||||
|
light.test_service_2:
|
||||||
|
names: Test Service
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area
|
||||||
|
light.test_service_3:
|
||||||
|
names: Test Service
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area
|
||||||
|
light.test_device_2:
|
||||||
|
names: Test Device 2
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area 2
|
||||||
|
light.test_device_3:
|
||||||
|
names: Test Device 3
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area 2
|
||||||
|
light.test_device_4:
|
||||||
|
names: Test Device 4
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area 2
|
||||||
|
light.test_device_3_2:
|
||||||
|
names: Test Device 3
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area 2
|
||||||
|
light.none:
|
||||||
|
names: None
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area 2
|
||||||
|
light.1:
|
||||||
|
names: '1'
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area 2
|
||||||
|
|
||||||
''',
|
''',
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
}),
|
}),
|
||||||
@ -318,7 +360,49 @@
|
|||||||
Answer in plain text. Keep it simple and to the point.
|
Answer in plain text. Keep it simple and to the point.
|
||||||
The current time is 05:00:00.
|
The current time is 05:00:00.
|
||||||
Today's date is 05/24/24.
|
Today's date is 05/24/24.
|
||||||
Call the intent tools to control Home Assistant. Just pass the name to the intent.
|
Call the intent tools to control Home Assistant. Just pass the name to the intent. When controlling an area, prefer passing area name.
|
||||||
|
An overview of the areas and the devices in this smart home:
|
||||||
|
light.test_device:
|
||||||
|
names: Test Device
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area
|
||||||
|
light.test_service:
|
||||||
|
names: Test Service
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area
|
||||||
|
light.test_service_2:
|
||||||
|
names: Test Service
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area
|
||||||
|
light.test_service_3:
|
||||||
|
names: Test Service
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area
|
||||||
|
light.test_device_2:
|
||||||
|
names: Test Device 2
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area 2
|
||||||
|
light.test_device_3:
|
||||||
|
names: Test Device 3
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area 2
|
||||||
|
light.test_device_4:
|
||||||
|
names: Test Device 4
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area 2
|
||||||
|
light.test_device_3_2:
|
||||||
|
names: Test Device 3
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area 2
|
||||||
|
light.none:
|
||||||
|
names: None
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area 2
|
||||||
|
light.1:
|
||||||
|
names: '1'
|
||||||
|
state: unavailable
|
||||||
|
areas: Test Area 2
|
||||||
|
|
||||||
''',
|
''',
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
}),
|
}),
|
||||||
|
@ -17,11 +17,13 @@ from homeassistant.exceptions import HomeAssistantError
|
|||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
area_registry as ar,
|
area_registry as ar,
|
||||||
device_registry as dr,
|
device_registry as dr,
|
||||||
|
entity_registry as er,
|
||||||
intent,
|
intent,
|
||||||
llm,
|
llm,
|
||||||
)
|
)
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@ -47,9 +49,11 @@ async def test_default_prompt(
|
|||||||
mock_init_component,
|
mock_init_component,
|
||||||
area_registry: ar.AreaRegistry,
|
area_registry: ar.AreaRegistry,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
agent_id: str | None,
|
agent_id: str | None,
|
||||||
config_entry_options: {},
|
config_entry_options: {},
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the default prompt works."""
|
"""Test that the default prompt works."""
|
||||||
entry = MockConfigEntry(title=None)
|
entry = MockConfigEntry(title=None)
|
||||||
@ -64,7 +68,22 @@ async def test_default_prompt(
|
|||||||
mock_config_entry,
|
mock_config_entry,
|
||||||
options={**mock_config_entry.options, **config_entry_options},
|
options={**mock_config_entry.options, **config_entry_options},
|
||||||
)
|
)
|
||||||
|
entities = []
|
||||||
|
|
||||||
|
def create_entity(device: dr.DeviceEntry) -> None:
|
||||||
|
"""Create an entity for a device and track entity_id."""
|
||||||
|
entity = entity_registry.async_get_or_create(
|
||||||
|
"light",
|
||||||
|
"test",
|
||||||
|
device.id,
|
||||||
|
device_id=device.id,
|
||||||
|
original_name=str(device.name),
|
||||||
|
suggested_object_id=str(device.name),
|
||||||
|
)
|
||||||
|
entity.write_unavailable_state(hass)
|
||||||
|
entities.append(entity.entity_id)
|
||||||
|
|
||||||
|
create_entity(
|
||||||
device_registry.async_get_or_create(
|
device_registry.async_get_or_create(
|
||||||
config_entry_id=entry.entry_id,
|
config_entry_id=entry.entry_id,
|
||||||
connections={("test", "1234")},
|
connections={("test", "1234")},
|
||||||
@ -73,7 +92,9 @@ async def test_default_prompt(
|
|||||||
model="Test Model",
|
model="Test Model",
|
||||||
suggested_area="Test Area",
|
suggested_area="Test Area",
|
||||||
)
|
)
|
||||||
|
)
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
|
create_entity(
|
||||||
device_registry.async_get_or_create(
|
device_registry.async_get_or_create(
|
||||||
config_entry_id=entry.entry_id,
|
config_entry_id=entry.entry_id,
|
||||||
connections={("test", f"{i}abcd")},
|
connections={("test", f"{i}abcd")},
|
||||||
@ -83,6 +104,8 @@ async def test_default_prompt(
|
|||||||
suggested_area="Test Area",
|
suggested_area="Test Area",
|
||||||
entry_type=dr.DeviceEntryType.SERVICE,
|
entry_type=dr.DeviceEntryType.SERVICE,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
create_entity(
|
||||||
device_registry.async_get_or_create(
|
device_registry.async_get_or_create(
|
||||||
config_entry_id=entry.entry_id,
|
config_entry_id=entry.entry_id,
|
||||||
connections={("test", "5678")},
|
connections={("test", "5678")},
|
||||||
@ -91,6 +114,8 @@ async def test_default_prompt(
|
|||||||
model="Device 2",
|
model="Device 2",
|
||||||
suggested_area="Test Area 2",
|
suggested_area="Test Area 2",
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
create_entity(
|
||||||
device_registry.async_get_or_create(
|
device_registry.async_get_or_create(
|
||||||
config_entry_id=entry.entry_id,
|
config_entry_id=entry.entry_id,
|
||||||
connections={("test", "9876")},
|
connections={("test", "9876")},
|
||||||
@ -99,12 +124,15 @@ async def test_default_prompt(
|
|||||||
model="Test Model 3A",
|
model="Test Model 3A",
|
||||||
suggested_area="Test Area 2",
|
suggested_area="Test Area 2",
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
create_entity(
|
||||||
device_registry.async_get_or_create(
|
device_registry.async_get_or_create(
|
||||||
config_entry_id=entry.entry_id,
|
config_entry_id=entry.entry_id,
|
||||||
connections={("test", "qwer")},
|
connections={("test", "qwer")},
|
||||||
name="Test Device 4",
|
name="Test Device 4",
|
||||||
suggested_area="Test Area 2",
|
suggested_area="Test Area 2",
|
||||||
)
|
)
|
||||||
|
)
|
||||||
device = device_registry.async_get_or_create(
|
device = device_registry.async_get_or_create(
|
||||||
config_entry_id=entry.entry_id,
|
config_entry_id=entry.entry_id,
|
||||||
connections={("test", "9876-disabled")},
|
connections={("test", "9876-disabled")},
|
||||||
@ -116,6 +144,8 @@ async def test_default_prompt(
|
|||||||
device_registry.async_update_device(
|
device_registry.async_update_device(
|
||||||
device.id, disabled_by=dr.DeviceEntryDisabler.USER
|
device.id, disabled_by=dr.DeviceEntryDisabler.USER
|
||||||
)
|
)
|
||||||
|
create_entity(device)
|
||||||
|
create_entity(
|
||||||
device_registry.async_get_or_create(
|
device_registry.async_get_or_create(
|
||||||
config_entry_id=entry.entry_id,
|
config_entry_id=entry.entry_id,
|
||||||
connections={("test", "9876-no-name")},
|
connections={("test", "9876-no-name")},
|
||||||
@ -123,6 +153,8 @@ async def test_default_prompt(
|
|||||||
model="Test Model NoName",
|
model="Test Model NoName",
|
||||||
suggested_area="Test Area 2",
|
suggested_area="Test Area 2",
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
create_entity(
|
||||||
device_registry.async_get_or_create(
|
device_registry.async_get_or_create(
|
||||||
config_entry_id=entry.entry_id,
|
config_entry_id=entry.entry_id,
|
||||||
connections={("test", "9876-integer-values")},
|
connections={("test", "9876-integer-values")},
|
||||||
@ -131,6 +163,21 @@ async def test_default_prompt(
|
|||||||
model=3,
|
model=3,
|
||||||
suggested_area="Test Area 2",
|
suggested_area="Test Area 2",
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set options for registered entities
|
||||||
|
ws_client = await hass_ws_client(hass)
|
||||||
|
await ws_client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "homeassistant/expose_entity",
|
||||||
|
"assistants": ["conversation"],
|
||||||
|
"entity_ids": entities,
|
||||||
|
"should_expose": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
response = await ws_client.receive_json()
|
||||||
|
assert response["success"]
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("google.generativeai.GenerativeModel") as mock_model,
|
patch("google.generativeai.GenerativeModel") as mock_model,
|
||||||
patch(
|
patch(
|
||||||
|
@ -11,10 +11,13 @@ from homeassistant.helpers import (
|
|||||||
area_registry as ar,
|
area_registry as ar,
|
||||||
config_validation as cv,
|
config_validation as cv,
|
||||||
device_registry as dr,
|
device_registry as dr,
|
||||||
|
entity_registry as er,
|
||||||
floor_registry as fr,
|
floor_registry as fr,
|
||||||
intent,
|
intent,
|
||||||
llm,
|
llm,
|
||||||
)
|
)
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
from homeassistant.util import yaml
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
@ -158,10 +161,12 @@ async def test_assist_api_description(hass: HomeAssistant) -> None:
|
|||||||
async def test_assist_api_prompt(
|
async def test_assist_api_prompt(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
area_registry: ar.AreaRegistry,
|
area_registry: ar.AreaRegistry,
|
||||||
floor_registry: fr.FloorRegistry,
|
floor_registry: fr.FloorRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test prompt for the assist API."""
|
"""Test prompt for the assist API."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
context = Context()
|
context = Context()
|
||||||
tool_input = llm.ToolInput(
|
tool_input = llm.ToolInput(
|
||||||
tool_name=None,
|
tool_name=None,
|
||||||
@ -170,41 +175,232 @@ async def test_assist_api_prompt(
|
|||||||
context=context,
|
context=context,
|
||||||
user_prompt="test_text",
|
user_prompt="test_text",
|
||||||
language="*",
|
language="*",
|
||||||
assistant="test_assistant",
|
assistant="conversation",
|
||||||
device_id="test_device",
|
device_id="test_device",
|
||||||
)
|
)
|
||||||
api = llm.async_get_api(hass, "assist")
|
api = llm.async_get_api(hass, "assist")
|
||||||
prompt = await api.async_get_api_prompt(tool_input)
|
prompt = await api.async_get_api_prompt(tool_input)
|
||||||
assert prompt == (
|
assert prompt == (
|
||||||
"Call the intent tools to control Home Assistant."
|
"Only if the user wants to control a device, tell them to expose entities to their "
|
||||||
" Just pass the name to the intent."
|
"voice assistant in Home Assistant."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Expose entities
|
||||||
entry = MockConfigEntry(title=None)
|
entry = MockConfigEntry(title=None)
|
||||||
entry.add_to_hass(hass)
|
entry.add_to_hass(hass)
|
||||||
tool_input.device_id = device_registry.async_get_or_create(
|
device = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections={("test", "1234")},
|
||||||
|
suggested_area="Test Area",
|
||||||
|
)
|
||||||
|
area = area_registry.async_get_area_by_name("Test Area")
|
||||||
|
area_registry.async_update(area.id, aliases=["Alternative name"])
|
||||||
|
entry1 = entity_registry.async_get_or_create(
|
||||||
|
"light",
|
||||||
|
"kitchen",
|
||||||
|
"mock-id-kitchen",
|
||||||
|
original_name="Kitchen",
|
||||||
|
suggested_object_id="kitchen",
|
||||||
|
)
|
||||||
|
entry2 = entity_registry.async_get_or_create(
|
||||||
|
"light",
|
||||||
|
"living_room",
|
||||||
|
"mock-id-living-room",
|
||||||
|
original_name="Living Room",
|
||||||
|
suggested_object_id="living_room",
|
||||||
|
device_id=device.id,
|
||||||
|
)
|
||||||
|
hass.states.async_set(entry1.entity_id, "on", {"friendly_name": "Kitchen"})
|
||||||
|
hass.states.async_set(entry2.entity_id, "on", {"friendly_name": "Living Room"})
|
||||||
|
|
||||||
|
def create_entity(device: dr.DeviceEntry, write_state=True) -> None:
|
||||||
|
"""Create an entity for a device and track entity_id."""
|
||||||
|
entity = entity_registry.async_get_or_create(
|
||||||
|
"light",
|
||||||
|
"test",
|
||||||
|
device.id,
|
||||||
|
device_id=device.id,
|
||||||
|
original_name=str(device.name or "Unnamed Device"),
|
||||||
|
suggested_object_id=str(device.name or "unnamed_device"),
|
||||||
|
)
|
||||||
|
if write_state:
|
||||||
|
entity.write_unavailable_state(hass)
|
||||||
|
|
||||||
|
create_entity(
|
||||||
|
device_registry.async_get_or_create(
|
||||||
config_entry_id=entry.entry_id,
|
config_entry_id=entry.entry_id,
|
||||||
connections={("test", "1234")},
|
connections={("test", "1234")},
|
||||||
name="Test Device",
|
name="Test Device",
|
||||||
manufacturer="Test Manufacturer",
|
manufacturer="Test Manufacturer",
|
||||||
model="Test Model",
|
model="Test Model",
|
||||||
suggested_area="Test Area",
|
suggested_area="Test Area",
|
||||||
).id
|
)
|
||||||
prompt = await api.async_get_api_prompt(tool_input)
|
)
|
||||||
assert prompt == (
|
for i in range(3):
|
||||||
"Call the intent tools to control Home Assistant."
|
create_entity(
|
||||||
" Just pass the name to the intent. You are in Test Area."
|
device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections={("test", f"{i}abcd")},
|
||||||
|
name="Test Service",
|
||||||
|
manufacturer="Test Manufacturer",
|
||||||
|
model="Test Model",
|
||||||
|
suggested_area="Test Area",
|
||||||
|
entry_type=dr.DeviceEntryType.SERVICE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
create_entity(
|
||||||
|
device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections={("test", "5678")},
|
||||||
|
name="Test Device 2",
|
||||||
|
manufacturer="Test Manufacturer 2",
|
||||||
|
model="Device 2",
|
||||||
|
suggested_area="Test Area 2",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
create_entity(
|
||||||
|
device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections={("test", "9876")},
|
||||||
|
name="Test Device 3",
|
||||||
|
manufacturer="Test Manufacturer 3",
|
||||||
|
model="Test Model 3A",
|
||||||
|
suggested_area="Test Area 2",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
create_entity(
|
||||||
|
device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections={("test", "qwer")},
|
||||||
|
name="Test Device 4",
|
||||||
|
suggested_area="Test Area 2",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device2 = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections={("test", "9876-disabled")},
|
||||||
|
name="Test Device 3 - disabled",
|
||||||
|
manufacturer="Test Manufacturer 3",
|
||||||
|
model="Test Model 3A",
|
||||||
|
suggested_area="Test Area 2",
|
||||||
|
)
|
||||||
|
device_registry.async_update_device(
|
||||||
|
device2.id, disabled_by=dr.DeviceEntryDisabler.USER
|
||||||
|
)
|
||||||
|
create_entity(device2, False)
|
||||||
|
create_entity(
|
||||||
|
device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections={("test", "9876-no-name")},
|
||||||
|
manufacturer="Test Manufacturer NoName",
|
||||||
|
model="Test Model NoName",
|
||||||
|
suggested_area="Test Area 2",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
create_entity(
|
||||||
|
device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections={("test", "9876-integer-values")},
|
||||||
|
name=1,
|
||||||
|
manufacturer=2,
|
||||||
|
model=3,
|
||||||
|
suggested_area="Test Area 2",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
exposed_entities = llm._get_exposed_entities(hass, tool_input.assistant)
|
||||||
|
assert exposed_entities == {
|
||||||
|
"light.1": {
|
||||||
|
"areas": "Test Area 2",
|
||||||
|
"names": "1",
|
||||||
|
"state": "unavailable",
|
||||||
|
},
|
||||||
|
entry1.entity_id: {
|
||||||
|
"names": "Kitchen",
|
||||||
|
"state": "on",
|
||||||
|
},
|
||||||
|
entry2.entity_id: {
|
||||||
|
"areas": "Test Area, Alternative name",
|
||||||
|
"names": "Living Room",
|
||||||
|
"state": "on",
|
||||||
|
},
|
||||||
|
"light.test_device": {
|
||||||
|
"areas": "Test Area, Alternative name",
|
||||||
|
"names": "Test Device",
|
||||||
|
"state": "unavailable",
|
||||||
|
},
|
||||||
|
"light.test_device_2": {
|
||||||
|
"areas": "Test Area 2",
|
||||||
|
"names": "Test Device 2",
|
||||||
|
"state": "unavailable",
|
||||||
|
},
|
||||||
|
"light.test_device_3": {
|
||||||
|
"areas": "Test Area 2",
|
||||||
|
"names": "Test Device 3",
|
||||||
|
"state": "unavailable",
|
||||||
|
},
|
||||||
|
"light.test_device_4": {
|
||||||
|
"areas": "Test Area 2",
|
||||||
|
"names": "Test Device 4",
|
||||||
|
"state": "unavailable",
|
||||||
|
},
|
||||||
|
"light.test_service": {
|
||||||
|
"areas": "Test Area, Alternative name",
|
||||||
|
"names": "Test Service",
|
||||||
|
"state": "unavailable",
|
||||||
|
},
|
||||||
|
"light.test_service_2": {
|
||||||
|
"areas": "Test Area, Alternative name",
|
||||||
|
"names": "Test Service",
|
||||||
|
"state": "unavailable",
|
||||||
|
},
|
||||||
|
"light.test_service_3": {
|
||||||
|
"areas": "Test Area, Alternative name",
|
||||||
|
"names": "Test Service",
|
||||||
|
"state": "unavailable",
|
||||||
|
},
|
||||||
|
"light.unnamed_device": {
|
||||||
|
"areas": "Test Area 2",
|
||||||
|
"names": "Unnamed Device",
|
||||||
|
"state": "unavailable",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
exposed_entities_prompt = (
|
||||||
|
"An overview of the areas and the devices in this smart home:\n"
|
||||||
|
+ yaml.dump(exposed_entities)
|
||||||
|
)
|
||||||
|
first_part_prompt = (
|
||||||
|
"Call the intent tools to control Home Assistant. "
|
||||||
|
"Just pass the name to the intent. "
|
||||||
|
"When controlling an area, prefer passing area name."
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = await api.async_get_api_prompt(tool_input)
|
||||||
|
assert prompt == (
|
||||||
|
f"""{first_part_prompt}
|
||||||
|
{exposed_entities_prompt}"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fake that request is made from a specific device ID
|
||||||
|
tool_input.device_id = device.id
|
||||||
|
prompt = await api.async_get_api_prompt(tool_input)
|
||||||
|
assert prompt == (
|
||||||
|
f"""{first_part_prompt}
|
||||||
|
You are in Test Area.
|
||||||
|
{exposed_entities_prompt}"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add floor
|
||||||
floor = floor_registry.async_create("second floor")
|
floor = floor_registry.async_create("second floor")
|
||||||
area = area_registry.async_get_area_by_name("Test Area")
|
|
||||||
area_registry.async_update(area.id, floor_id=floor.floor_id)
|
area_registry.async_update(area.id, floor_id=floor.floor_id)
|
||||||
prompt = await api.async_get_api_prompt(tool_input)
|
prompt = await api.async_get_api_prompt(tool_input)
|
||||||
assert prompt == (
|
assert prompt == (
|
||||||
"Call the intent tools to control Home Assistant."
|
f"""{first_part_prompt}
|
||||||
" Just pass the name to the intent. You are in Test Area (second floor)."
|
You are in Test Area (second floor).
|
||||||
|
{exposed_entities_prompt}"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add user
|
||||||
context.user_id = "12345"
|
context.user_id = "12345"
|
||||||
mock_user = Mock()
|
mock_user = Mock()
|
||||||
mock_user.id = "12345"
|
mock_user.id = "12345"
|
||||||
@ -212,7 +408,8 @@ async def test_assist_api_prompt(
|
|||||||
with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user):
|
with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user):
|
||||||
prompt = await api.async_get_api_prompt(tool_input)
|
prompt = await api.async_get_api_prompt(tool_input)
|
||||||
assert prompt == (
|
assert prompt == (
|
||||||
"Call the intent tools to control Home Assistant."
|
f"""{first_part_prompt}
|
||||||
" Just pass the name to the intent. You are in Test Area (second floor)."
|
You are in Test Area (second floor).
|
||||||
" The user name is Test User."
|
The user name is Test User.
|
||||||
|
{exposed_entities_prompt}"""
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user