mirror of
https://github.com/home-assistant/core.git
synced 2025-11-09 02:49:40 +00:00
LLM Assist API to ignore intents if not needed for exposed entities or calling device (#118283)
* LLM Assist API to ignore timer intents if device doesn't support it * Refactor to use API instances * Extract ToolContext class * Limit exposed intents based on exposed entities
This commit is contained in:
@@ -5,6 +5,7 @@ from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.intent import async_register_timer_handler
|
||||
from homeassistant.core import Context, HomeAssistant, State
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import (
|
||||
@@ -22,53 +23,84 @@ from homeassistant.util import yaml
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def test_get_api_no_existing(hass: HomeAssistant) -> None:
|
||||
@pytest.fixture
|
||||
def tool_input_context() -> llm.ToolContext:
|
||||
"""Return tool input context."""
|
||||
return llm.ToolContext(
|
||||
platform="",
|
||||
context=None,
|
||||
user_prompt=None,
|
||||
language=None,
|
||||
assistant=None,
|
||||
device_id=None,
|
||||
)
|
||||
|
||||
|
||||
async def test_get_api_no_existing(
|
||||
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
||||
) -> None:
|
||||
"""Test getting an llm api where no config exists."""
|
||||
with pytest.raises(HomeAssistantError):
|
||||
llm.async_get_api(hass, "non-existing")
|
||||
await llm.async_get_api(hass, "non-existing", tool_input_context)
|
||||
|
||||
|
||||
async def test_register_api(hass: HomeAssistant) -> None:
|
||||
async def test_register_api(
|
||||
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
||||
) -> None:
|
||||
"""Test registering an llm api."""
|
||||
|
||||
class MyAPI(llm.API):
|
||||
async def async_get_api_prompt(self, tool_input: llm.ToolInput) -> str:
|
||||
"""Return a prompt for the tool."""
|
||||
return ""
|
||||
|
||||
def async_get_tools(self) -> list[llm.Tool]:
|
||||
async def async_get_api_instance(
|
||||
self, tool_input: llm.ToolInput
|
||||
) -> llm.APIInstance:
|
||||
"""Return a list of tools."""
|
||||
return []
|
||||
return llm.APIInstance(self, "", [], tool_input_context)
|
||||
|
||||
api = MyAPI(hass=hass, id="test", name="Test")
|
||||
llm.async_register_api(hass, api)
|
||||
|
||||
assert llm.async_get_api(hass, "test") is api
|
||||
instance = await llm.async_get_api(hass, "test", tool_input_context)
|
||||
assert instance.api is api
|
||||
assert api in llm.async_get_apis(hass)
|
||||
|
||||
with pytest.raises(HomeAssistantError):
|
||||
llm.async_register_api(hass, api)
|
||||
|
||||
|
||||
async def test_call_tool_no_existing(hass: HomeAssistant) -> None:
|
||||
async def test_call_tool_no_existing(
|
||||
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
||||
) -> None:
|
||||
"""Test calling an llm tool where no config exists."""
|
||||
instance = await llm.async_get_api(hass, "assist", tool_input_context)
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await llm.async_get_api(hass, "intent").async_call_tool(
|
||||
llm.ToolInput(
|
||||
"test_tool",
|
||||
{},
|
||||
"test_platform",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
await instance.async_call_tool(
|
||||
llm.ToolInput("test_tool", {}),
|
||||
)
|
||||
|
||||
|
||||
async def test_assist_api(hass: HomeAssistant) -> None:
|
||||
async def test_assist_api(
|
||||
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
||||
) -> None:
|
||||
"""Test Assist API."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
entity_registry.async_get_or_create(
|
||||
"light",
|
||||
"kitchen",
|
||||
"mock-id-kitchen",
|
||||
original_name="Kitchen",
|
||||
suggested_object_id="kitchen",
|
||||
).write_unavailable_state(hass)
|
||||
|
||||
test_context = Context()
|
||||
tool_context = llm.ToolContext(
|
||||
platform="test_platform",
|
||||
context=test_context,
|
||||
user_prompt="test_text",
|
||||
language="*",
|
||||
assistant="conversation",
|
||||
device_id="test_device",
|
||||
)
|
||||
schema = {
|
||||
vol.Optional("area"): cv.string,
|
||||
vol.Optional("floor"): cv.string,
|
||||
@@ -77,22 +109,33 @@ async def test_assist_api(hass: HomeAssistant) -> None:
|
||||
class MyIntentHandler(intent.IntentHandler):
|
||||
intent_type = "test_intent"
|
||||
slot_schema = schema
|
||||
platforms = set() # Match none
|
||||
|
||||
intent_handler = MyIntentHandler()
|
||||
|
||||
intent.async_register(hass, intent_handler)
|
||||
|
||||
assert len(llm.async_get_apis(hass)) == 1
|
||||
api = llm.async_get_api(hass, "assist")
|
||||
tools = api.async_get_tools()
|
||||
assert len(tools) == 1
|
||||
tool = tools[0]
|
||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||
assert len(api.tools) == 0
|
||||
|
||||
# Match all
|
||||
intent_handler.platforms = None
|
||||
|
||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||
assert len(api.tools) == 1
|
||||
|
||||
# Match specific domain
|
||||
intent_handler.platforms = {"light"}
|
||||
|
||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||
assert len(api.tools) == 1
|
||||
tool = api.tools[0]
|
||||
assert tool.name == "test_intent"
|
||||
assert tool.description == "Execute Home Assistant test_intent intent"
|
||||
assert tool.parameters == vol.Schema(intent_handler.slot_schema)
|
||||
assert str(tool) == "<IntentTool - test_intent>"
|
||||
|
||||
test_context = Context()
|
||||
assert test_context.json_fragment # To reproduce an error case in tracing
|
||||
intent_response = intent.IntentResponse("*")
|
||||
intent_response.matched_states = [State("light.matched", "on")]
|
||||
@@ -100,12 +143,6 @@ async def test_assist_api(hass: HomeAssistant) -> None:
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name="test_intent",
|
||||
tool_args={"area": "kitchen", "floor": "ground_floor"},
|
||||
platform="test_platform",
|
||||
context=test_context,
|
||||
user_prompt="test_text",
|
||||
language="*",
|
||||
assistant="test_assistant",
|
||||
device_id="test_device",
|
||||
)
|
||||
|
||||
with patch(
|
||||
@@ -114,18 +151,18 @@ async def test_assist_api(hass: HomeAssistant) -> None:
|
||||
response = await api.async_call_tool(tool_input)
|
||||
|
||||
mock_intent_handle.assert_awaited_once_with(
|
||||
hass,
|
||||
"test_platform",
|
||||
"test_intent",
|
||||
{
|
||||
hass=hass,
|
||||
platform="test_platform",
|
||||
intent_type="test_intent",
|
||||
slots={
|
||||
"area": {"value": "kitchen"},
|
||||
"floor": {"value": "ground_floor"},
|
||||
},
|
||||
"test_text",
|
||||
test_context,
|
||||
"*",
|
||||
"test_assistant",
|
||||
"test_device",
|
||||
text_input="test_text",
|
||||
context=test_context,
|
||||
language="*",
|
||||
assistant="conversation",
|
||||
device_id="test_device",
|
||||
)
|
||||
assert response == {
|
||||
"card": {},
|
||||
@@ -140,7 +177,27 @@ async def test_assist_api(hass: HomeAssistant) -> None:
|
||||
}
|
||||
|
||||
|
||||
async def test_assist_api_description(hass: HomeAssistant) -> None:
|
||||
async def test_assist_api_get_timer_tools(
|
||||
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
||||
) -> None:
|
||||
"""Test getting timer tools with Assist API."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
api = await llm.async_get_api(hass, "assist", tool_input_context)
|
||||
|
||||
assert "HassStartTimer" not in [tool.name for tool in api.tools]
|
||||
|
||||
tool_input_context.device_id = "test_device"
|
||||
|
||||
async_register_timer_handler(hass, "test_device", lambda *args: None)
|
||||
|
||||
api = await llm.async_get_api(hass, "assist", tool_input_context)
|
||||
assert "HassStartTimer" in [tool.name for tool in api.tools]
|
||||
|
||||
|
||||
async def test_assist_api_description(
|
||||
hass: HomeAssistant, tool_input_context: llm.ToolContext
|
||||
) -> None:
|
||||
"""Test intent description with Assist API."""
|
||||
|
||||
class MyIntentHandler(intent.IntentHandler):
|
||||
@@ -150,10 +207,9 @@ async def test_assist_api_description(hass: HomeAssistant) -> None:
|
||||
intent.async_register(hass, MyIntentHandler())
|
||||
|
||||
assert len(llm.async_get_apis(hass)) == 1
|
||||
api = llm.async_get_api(hass, "assist")
|
||||
tools = api.async_get_tools()
|
||||
assert len(tools) == 1
|
||||
tool = tools[0]
|
||||
api = await llm.async_get_api(hass, "assist", tool_input_context)
|
||||
assert len(api.tools) == 1
|
||||
tool = api.tools[0]
|
||||
assert tool.name == "test_intent"
|
||||
assert tool.description == "my intent handler"
|
||||
|
||||
@@ -167,20 +223,18 @@ async def test_assist_api_prompt(
|
||||
) -> None:
|
||||
"""Test prompt for the assist API."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
context = Context()
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name=None,
|
||||
tool_args=None,
|
||||
tool_context = llm.ToolContext(
|
||||
platform="test_platform",
|
||||
context=context,
|
||||
user_prompt="test_text",
|
||||
language="*",
|
||||
assistant="conversation",
|
||||
device_id="test_device",
|
||||
device_id=None,
|
||||
)
|
||||
api = llm.async_get_api(hass, "assist")
|
||||
prompt = await api.async_get_api_prompt(tool_input)
|
||||
assert prompt == (
|
||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||
assert api.api_prompt == (
|
||||
"Only if the user wants to control a device, tell them to expose entities to their "
|
||||
"voice assistant in Home Assistant."
|
||||
)
|
||||
@@ -308,7 +362,7 @@ async def test_assist_api_prompt(
|
||||
)
|
||||
)
|
||||
|
||||
exposed_entities = llm._get_exposed_entities(hass, tool_input.assistant)
|
||||
exposed_entities = llm._get_exposed_entities(hass, tool_context.assistant)
|
||||
assert exposed_entities == {
|
||||
"light.1": {
|
||||
"areas": "Test Area 2",
|
||||
@@ -373,40 +427,55 @@ async def test_assist_api_prompt(
|
||||
"Call the intent tools to control Home Assistant. "
|
||||
"When controlling an area, prefer passing area name."
|
||||
)
|
||||
no_timer_prompt = "This device does not support timers."
|
||||
|
||||
prompt = await api.async_get_api_prompt(tool_input)
|
||||
area_prompt = (
|
||||
"Reject all generic commands like 'turn on the lights' because we don't know in what area "
|
||||
"this conversation is happening."
|
||||
)
|
||||
assert prompt == (
|
||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||
assert api.api_prompt == (
|
||||
f"""{first_part_prompt}
|
||||
{area_prompt}
|
||||
{no_timer_prompt}
|
||||
{exposed_entities_prompt}"""
|
||||
)
|
||||
|
||||
# Fake that request is made from a specific device ID
|
||||
tool_input.device_id = device.id
|
||||
prompt = await api.async_get_api_prompt(tool_input)
|
||||
# Fake that request is made from a specific device ID with an area
|
||||
tool_context.device_id = device.id
|
||||
area_prompt = (
|
||||
"You are in area Test Area and all generic commands like 'turn on the lights' "
|
||||
"should target this area."
|
||||
)
|
||||
assert prompt == (
|
||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||
assert api.api_prompt == (
|
||||
f"""{first_part_prompt}
|
||||
{area_prompt}
|
||||
{no_timer_prompt}
|
||||
{exposed_entities_prompt}"""
|
||||
)
|
||||
|
||||
# Add floor
|
||||
floor = floor_registry.async_create("2")
|
||||
area_registry.async_update(area.id, floor_id=floor.floor_id)
|
||||
prompt = await api.async_get_api_prompt(tool_input)
|
||||
area_prompt = (
|
||||
"You are in area Test Area (floor 2) and all generic commands like 'turn on the lights' "
|
||||
"should target this area."
|
||||
)
|
||||
assert prompt == (
|
||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||
assert api.api_prompt == (
|
||||
f"""{first_part_prompt}
|
||||
{area_prompt}
|
||||
{no_timer_prompt}
|
||||
{exposed_entities_prompt}"""
|
||||
)
|
||||
|
||||
# Register device for timers
|
||||
async_register_timer_handler(hass, device.id, lambda *args: None)
|
||||
|
||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||
# The no_timer_prompt is gone
|
||||
assert api.api_prompt == (
|
||||
f"""{first_part_prompt}
|
||||
{area_prompt}
|
||||
{exposed_entities_prompt}"""
|
||||
@@ -418,8 +487,8 @@ async def test_assist_api_prompt(
|
||||
mock_user.id = "12345"
|
||||
mock_user.name = "Test User"
|
||||
with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user):
|
||||
prompt = await api.async_get_api_prompt(tool_input)
|
||||
assert prompt == (
|
||||
api = await llm.async_get_api(hass, "assist", tool_context)
|
||||
assert api.api_prompt == (
|
||||
f"""{first_part_prompt}
|
||||
{area_prompt}
|
||||
The user name is Test User.
|
||||
|
||||
Reference in New Issue
Block a user