Update ollama to allow selecting multiple LLM APIs (#142445)

* Update ollama to allow selecting multiple LLM APIs

* Update homeassistant/helpers/llm.py

* Avoid gather since these don't do I/O

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
Allen Porter 2025-04-13 15:37:46 -07:00 committed by GitHub
parent 8b88272bc0
commit d91528648f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 219 additions and 18 deletions

View File

@ -387,7 +387,7 @@ class ChatLog:
self,
conversing_domain: str,
user_input: ConversationInput,
user_llm_hass_api: str | None = None,
user_llm_hass_api: str | list[str] | None = None,
user_llm_prompt: str | None = None,
) -> None:
"""Set the LLM system prompt."""

View File

@ -215,8 +215,6 @@ class OllamaOptionsFlow(OptionsFlow):
) -> ConfigFlowResult:
"""Manage the options."""
if user_input is not None:
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(
title=_get_title(self.model), data=user_input
)
@ -234,18 +232,12 @@ def ollama_config_option_schema(
) -> dict:
"""Ollama options schema."""
hass_apis: list[SelectOptionDict] = [
SelectOptionDict(
label="No control",
value="none",
)
]
hass_apis.extend(
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
)
]
return {
vol.Optional(
@ -259,8 +251,7 @@ def ollama_config_option_schema(
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
vol.Optional(
CONF_NUM_CTX,
description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},

View File

@ -110,15 +110,29 @@ def async_register_api(hass: HomeAssistant, api: API) -> Callable[[], None]:
async def async_get_api(
    hass: HomeAssistant, api_id: str | list[str], llm_context: LLMContext
) -> APIInstance:
    """Get an API.

    This returns a single APIInstance for one or more API ids, merging into
    a single instance if necessary.

    Raises HomeAssistantError if any of the requested API ids is not
    registered.
    """
    apis = _async_get_apis(hass)
    # Normalize to a list so a single id and a list of ids share one code path.
    if isinstance(api_id, str):
        api_id = [api_id]
    # Validate every id up front so the error names the missing API.
    for key in api_id:
        if key not in apis:
            raise HomeAssistantError(f"API {key} not found")
    api: API
    if len(api_id) == 1:
        # Single API: return its instance directly, no merging overhead.
        api = apis[api_id[0]]
    else:
        api = MergedAPI([apis[key] for key in api_id])
    return await api.async_get_api_instance(llm_context)
@callback
@ -286,6 +300,102 @@ class IntentTool(Tool):
return response
class NamespacedTool(Tool):
    """A tool that wraps another tool, prepending a namespace.

    This is used to support tools from multiple API. This tool dispatches
    the original tool with the original non-namespaced name.
    """

    def __init__(self, namespace: str, tool: Tool) -> None:
        """Init the class."""
        self.tool = tool
        self.namespace = namespace
        # Expose the tool to the LLM under a namespaced name to avoid
        # collisions between tools from different APIs.
        self.name = f"{namespace}.{tool.name}"
        self.description = tool.description
        self.parameters = tool.parameters

    async def async_call(
        self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
    ) -> JsonObjectType:
        """Handle the intent."""
        # Strip the namespace: the wrapped tool expects its original name.
        delegated_input = ToolInput(
            tool_name=self.tool.name,
            tool_args=tool_input.tool_args,
            id=tool_input.id,
        )
        return await self.tool.async_call(hass, delegated_input, llm_context)
class MergedAPI(API):
    """An API that represents a merged view of multiple APIs."""

    def __init__(self, llm_apis: list[API]) -> None:
        """Init the class.

        Raises ValueError if no APIs are provided or if slugified API ids
        collide.
        """
        if not llm_apis:
            raise ValueError("No APIs provided")
        hass = llm_apis[0].hass
        api_ids = [unicode_slug.slugify(api.id) for api in llm_apis]
        if len(set(api_ids)) != len(api_ids):
            raise ValueError("API IDs must be unique")
        super().__init__(
            hass=hass,
            # Reuse the slugs computed above instead of slugifying twice.
            id="|".join(api_ids),
            name="Merged LLM API",
        )
        self.llm_apis = llm_apis

    async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
        """Return the instance of the API."""
        # These usually don't do I/O and execute right away
        llm_apis = [
            await llm_api.async_get_api_instance(llm_context)
            for llm_api in self.llm_apis
        ]
        prompt_parts = []
        tools: list[Tool] = []
        for api_instance in llm_apis:
            # Namespace by slugified API name; tools are wrapped so each call
            # is dispatched back to the original tool under its original name.
            namespace = unicode_slug.slugify(api_instance.api.name)
            prompt_parts.append(
                f'Follow these instructions for tools from "{namespace}":\n'
            )
            prompt_parts.append(api_instance.api_prompt)
            prompt_parts.append("\n\n")
            tools.extend(
                [NamespacedTool(namespace, tool) for tool in api_instance.tools]
            )

        return APIInstance(
            api=self,
            api_prompt="".join(prompt_parts),
            llm_context=llm_context,
            tools=tools,
            custom_serializer=self._custom_serializer(llm_apis),
        )

    def _custom_serializer(
        self, llm_apis: list[APIInstance]
    ) -> Callable[[Any], Any] | None:
        """Combine the child APIs' serializers into one, or None if none exist."""
        serializers = [
            api_instance.custom_serializer
            for api_instance in llm_apis
            if api_instance.custom_serializer is not None
        ]
        if not serializers:
            return None

        def merged(x: Any) -> Any:
            # First serializer that produces a non-None result wins.
            for serializer in serializers:
                if (result := serializer(x)) is not None:
                    return result
            return x

        return merged
class AssistAPI(API):
"""API exposing Assist API to LLMs."""

View File

@ -139,6 +139,48 @@ async def test_unknown_llm_api(
assert exc_info.value.as_conversation_result().as_dict() == snapshot
async def test_multiple_llm_apis(
    hass: HomeAssistant,
    mock_conversation_input: ConversationInput,
) -> None:
    """Test referencing multiple LLM APIs from the chat log at once."""

    class MyTool(llm.Tool):
        """Test tool."""

        name = "test_tool"
        description = "Test function"
        parameters = vol.Schema(
            {vol.Optional("param1", description="Test parameters"): str}
        )

    class MyAPI(llm.API):
        """Test API."""

        async def async_get_api_instance(
            self, llm_context: llm.LLMContext
        ) -> llm.APIInstance:
            """Return a list of tools."""
            return llm.APIInstance(self, "My API Prompt", llm_context, [MyTool()])

    api = MyAPI(hass=hass, id="my-api", name="Test")
    llm.async_register_api(hass, api)

    with (
        chat_session.async_get_chat_session(hass) as session,
        async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
    ):
        # Passing a list of API ids should produce a single merged API.
        await chat_log.async_update_llm_data(
            conversing_domain="test",
            user_input=mock_conversation_input,
            user_llm_hass_api=["assist", "my-api"],
            user_llm_prompt=None,
        )
        # The merged API id is the individual ids joined with "|".
        assert chat_log.llm_api
        assert chat_log.llm_api.api.id == "assist|my-api"
async def test_template_error(
hass: HomeAssistant,
mock_conversation_input: ConversationInput,

View File

@ -25,6 +25,7 @@ from homeassistant.helpers import (
)
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
from homeassistant.util.json import JsonObjectType
from tests.common import MockConfigEntry, async_mock_service
@ -45,9 +46,12 @@ def llm_context() -> llm.LLMContext:
class MyAPI(llm.API):
    """Test API."""

    # Per-test configurable defaults; tests overwrite these on instances.
    # NOTE(review): the mutable [] class default is shared until overridden.
    prompt: str = ""
    tools: list[llm.Tool] = []

    async def async_get_api_instance(
        self, llm_context: llm.LLMContext
    ) -> llm.APIInstance:
        """Return a list of tools."""
        # Use the context passed by the caller. Previously the parameter was
        # an unused `_` and this line accidentally referenced the module-level
        # `llm_context` pytest fixture function instead.
        return llm.APIInstance(self, self.prompt, llm_context, self.tools)
async def test_get_api_no_existing(
@ -1326,3 +1330,57 @@ async def test_no_tools_exposed(hass: HomeAssistant) -> None:
)
api = await llm.async_get_api(hass, "assist", llm_context)
assert api.tools == []
async def test_merged_api(hass: HomeAssistant, llm_context: llm.LLMContext) -> None:
    """Test an API instance that merges multiple llm apis."""

    class MyTool(llm.Tool):
        """Test tool that echoes back the name and args it was invoked with."""

        def __init__(self, name: str, description: str) -> None:
            """Init the tool with the given name and description."""
            self.name = name
            self.description = description

        async def async_call(
            self, hass: HomeAssistant, tool_input: llm.ToolInput, _: llm.LLMContext
        ) -> JsonObjectType:
            """Echo the tool name and args so dispatch can be verified."""
            return {"result": {tool_input.tool_name: tool_input.tool_args}}

    api1 = MyAPI(hass=hass, id="api-1", name="API 1")
    api1.prompt = "This is prompt 1"
    api1.tools = [MyTool(name="Tool_1", description="Description 1")]
    llm.async_register_api(hass, api1)

    api2 = MyAPI(hass=hass, id="api-2", name="API 2")
    api2.prompt = "This is prompt 2"
    api2.tools = [MyTool(name="Tool_2", description="Description 2")]
    llm.async_register_api(hass, api2)

    # Requesting two API ids yields a single merged APIInstance.
    instance = await llm.async_get_api(hass, ["api-1", "api-2"], llm_context)
    assert instance.api.id == "api-1|api-2"
    # Each API's prompt is prefixed with per-namespace instructions and
    # followed by a blank line.
    assert (
        instance.api_prompt
        == """Follow these instructions for tools from "api-1":
This is prompt 1

Follow these instructions for tools from "api-2":
This is prompt 2

"""
    )
    # Tool names are namespaced with the slugified API name.
    assert [(tool.name, tool.description) for tool in instance.tools] == [
        ("api-1.Tool_1", "Description 1"),
        ("api-2.Tool_2", "Description 2"),
    ]

    # The test tool returns back the provided arguments so we can verify
    # the original tool is invoked with the correct tool name and args.
    result = await instance.async_call_tool(
        llm.ToolInput(tool_name="api-1.Tool_1", tool_args={"arg1": "value1"})
    )
    assert result == {"result": {"Tool_1": {"arg1": "value1"}}}

    result = await instance.async_call_tool(
        llm.ToolInput(tool_name="api-2.Tool_2", tool_args={"arg2": "value2"})
    )
    assert result == {"result": {"Tool_2": {"arg2": "value2"}}}