Add API class to LLM helper (#117707)

* Add API class to LLM helper

* Add more tests

* Rename intent to assist to broaden scope
This commit is contained in:
Paulus Schoutsen 2024-05-18 21:14:05 -04:00 committed by GitHub
parent bfc52b9fab
commit d001e7daea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 125 additions and 42 deletions

View File

@ -2,10 +2,8 @@
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
import logging
from typing import Any from typing import Any
import voluptuous as vol import voluptuous as vol
@ -17,19 +15,53 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.json import JsonObjectType from homeassistant.util.json import JsonObjectType
from . import intent from . import intent
from .singleton import singleton
_LOGGER = logging.getLogger(__name__)
IGNORE_INTENTS = [ @singleton("llm")
intent.INTENT_NEVERMIND, @callback
intent.INTENT_GET_STATE, def _async_get_apis(hass: HomeAssistant) -> dict[str, API]:
INTENT_GET_WEATHER, """Get all the LLM APIs."""
INTENT_GET_TEMPERATURE, return {
] "assist": AssistAPI(
hass=hass,
id="assist",
name="Assist",
prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
),
}
@callback
def async_register_api(hass: HomeAssistant, api: API) -> None:
"""Register an API to be exposed to LLMs."""
apis = _async_get_apis(hass)
if api.id in apis:
raise HomeAssistantError(f"API {api.id} is already registered")
apis[api.id] = api
@callback
def async_get_api(hass: HomeAssistant, api_id: str) -> API:
"""Get an API."""
apis = _async_get_apis(hass)
if api_id not in apis:
raise HomeAssistantError(f"API {api_id} not found")
return apis[api_id]
@callback
def async_get_apis(hass: HomeAssistant) -> list[API]:
"""Get all the LLM APIs."""
return list(_async_get_apis(hass).values())
@dataclass(slots=True) @dataclass(slots=True)
class ToolInput: class ToolInput(ABC):
"""Tool input to be processed.""" """Tool input to be processed."""
tool_name: str tool_name: str
@ -60,34 +92,40 @@ class Tool:
return f"<{self.__class__.__name__} - {self.name}>" return f"<{self.__class__.__name__} - {self.name}>"
@callback @dataclass(slots=True, kw_only=True)
def async_get_tools(hass: HomeAssistant) -> Iterable[Tool]: class API(ABC):
"""Return a list of LLM tools.""" """An API to expose to LLMs."""
for intent_handler in intent.async_get(hass):
if intent_handler.intent_type not in IGNORE_INTENTS:
yield IntentTool(intent_handler)
hass: HomeAssistant
id: str
name: str
prompt_template: str
@callback @abstractmethod
async def async_call_tool(hass: HomeAssistant, tool_input: ToolInput) -> JsonObjectType: @callback
"""Call a LLM tool, validate args and return the response.""" def async_get_tools(self) -> list[Tool]:
for tool in async_get_tools(hass): """Return a list of tools."""
if tool.name == tool_input.tool_name: raise NotImplementedError
break
else:
raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
_tool_input = ToolInput( async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
tool_name=tool.name, """Call a LLM tool, validate args and return the response."""
tool_args=tool.parameters(tool_input.tool_args), for tool in self.async_get_tools():
platform=tool_input.platform, if tool.name == tool_input.tool_name:
context=tool_input.context or Context(), break
user_prompt=tool_input.user_prompt, else:
language=tool_input.language, raise HomeAssistantError(f'Tool "{tool_input.tool_name}" not found')
assistant=tool_input.assistant,
)
return await tool.async_call(hass, _tool_input) _tool_input = ToolInput(
tool_name=tool.name,
tool_args=tool.parameters(tool_input.tool_args),
platform=tool_input.platform,
context=tool_input.context or Context(),
user_prompt=tool_input.user_prompt,
language=tool_input.language,
assistant=tool_input.assistant,
)
return await tool.async_call(self.hass, _tool_input)
class IntentTool(Tool): class IntentTool(Tool):
@ -120,3 +158,23 @@ class IntentTool(Tool):
tool_input.assistant, tool_input.assistant,
) )
return intent_response.as_dict() return intent_response.as_dict()
class AssistAPI(API):
"""API exposing Assist API to LLMs."""
IGNORE_INTENTS = {
intent.INTENT_NEVERMIND,
intent.INTENT_GET_STATE,
INTENT_GET_WEATHER,
INTENT_GET_TEMPERATURE,
}
@callback
def async_get_tools(self) -> list[Tool]:
"""Return a list of LLM tools."""
return [
IntentTool(intent_handler)
for intent_handler in intent.async_get(self.hass)
if intent_handler.intent_type not in self.IGNORE_INTENTS
]

View File

@ -10,11 +10,33 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent, llm from homeassistant.helpers import config_validation as cv, intent, llm
async def test_get_api_no_existing(hass: HomeAssistant) -> None:
"""Test getting an llm api where no config exists."""
with pytest.raises(HomeAssistantError):
llm.async_get_api(hass, "non-existing")
async def test_register_api(hass: HomeAssistant) -> None:
"""Test registering an llm api."""
api = llm.AssistAPI(
hass=hass,
id="test",
name="Test",
prompt_template="Test",
)
llm.async_register_api(hass, api)
assert llm.async_get_api(hass, "test") is api
assert api in llm.async_get_apis(hass)
with pytest.raises(HomeAssistantError):
llm.async_register_api(hass, api)
async def test_call_tool_no_existing(hass: HomeAssistant) -> None: async def test_call_tool_no_existing(hass: HomeAssistant) -> None:
"""Test calling an llm tool where no config exists.""" """Test calling an llm tool where no config exists."""
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
await llm.async_call_tool( await llm.async_get_api(hass, "intent").async_call_tool(
hass,
llm.ToolInput( llm.ToolInput(
"test_tool", "test_tool",
{}, {},
@ -27,8 +49,8 @@ async def test_call_tool_no_existing(hass: HomeAssistant) -> None:
) )
async def test_intent_tool(hass: HomeAssistant) -> None: async def test_assist_api(hass: HomeAssistant) -> None:
"""Test IntentTool class.""" """Test Assist API."""
schema = { schema = {
vol.Optional("area"): cv.string, vol.Optional("area"): cv.string,
vol.Optional("floor"): cv.string, vol.Optional("floor"): cv.string,
@ -42,8 +64,11 @@ async def test_intent_tool(hass: HomeAssistant) -> None:
intent.async_register(hass, intent_handler) intent.async_register(hass, intent_handler)
assert len(list(llm.async_get_tools(hass))) == 1 assert len(llm.async_get_apis(hass)) == 1
tool = list(llm.async_get_tools(hass))[0] api = llm.async_get_api(hass, "assist")
tools = api.async_get_tools()
assert len(tools) == 1
tool = tools[0]
assert tool.name == "test_intent" assert tool.name == "test_intent"
assert tool.description == "Execute Home Assistant test_intent intent" assert tool.description == "Execute Home Assistant test_intent intent"
assert tool.parameters == vol.Schema(intent_handler.slot_schema) assert tool.parameters == vol.Schema(intent_handler.slot_schema)
@ -66,7 +91,7 @@ async def test_intent_tool(hass: HomeAssistant) -> None:
with patch( with patch(
"homeassistant.helpers.intent.async_handle", return_value=intent_response "homeassistant.helpers.intent.async_handle", return_value=intent_response
) as mock_intent_handle: ) as mock_intent_handle:
response = await llm.async_call_tool(hass, tool_input) response = await api.async_call_tool(tool_input)
mock_intent_handle.assert_awaited_once_with( mock_intent_handle.assert_awaited_once_with(
hass, hass,