mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Update ollama to allow selecting mutiple LLM APIs (#142445)
* Update ollama to allow selecting mutiple LLM APIs * Update homeassistant/helpers/llm.py * Avoid gather since these don't do I/O --------- Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
8b88272bc0
commit
d91528648f
@ -387,7 +387,7 @@ class ChatLog:
|
||||
self,
|
||||
conversing_domain: str,
|
||||
user_input: ConversationInput,
|
||||
user_llm_hass_api: str | None = None,
|
||||
user_llm_hass_api: str | list[str] | None = None,
|
||||
user_llm_prompt: str | None = None,
|
||||
) -> None:
|
||||
"""Set the LLM system prompt."""
|
||||
|
@ -215,8 +215,6 @@ class OllamaOptionsFlow(OptionsFlow):
|
||||
) -> ConfigFlowResult:
|
||||
"""Manage the options."""
|
||||
if user_input is not None:
|
||||
if user_input[CONF_LLM_HASS_API] == "none":
|
||||
user_input.pop(CONF_LLM_HASS_API)
|
||||
return self.async_create_entry(
|
||||
title=_get_title(self.model), data=user_input
|
||||
)
|
||||
@ -234,18 +232,12 @@ def ollama_config_option_schema(
|
||||
) -> dict:
|
||||
"""Ollama options schema."""
|
||||
hass_apis: list[SelectOptionDict] = [
|
||||
SelectOptionDict(
|
||||
label="No control",
|
||||
value="none",
|
||||
)
|
||||
]
|
||||
hass_apis.extend(
|
||||
SelectOptionDict(
|
||||
label=api.name,
|
||||
value=api.id,
|
||||
)
|
||||
for api in llm.async_get_apis(hass)
|
||||
)
|
||||
]
|
||||
|
||||
return {
|
||||
vol.Optional(
|
||||
@ -259,8 +251,7 @@ def ollama_config_option_schema(
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||
default="none",
|
||||
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
|
||||
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
|
||||
vol.Optional(
|
||||
CONF_NUM_CTX,
|
||||
description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
|
||||
|
@ -110,15 +110,29 @@ def async_register_api(hass: HomeAssistant, api: API) -> Callable[[], None]:
|
||||
|
||||
|
||||
async def async_get_api(
|
||||
hass: HomeAssistant, api_id: str, llm_context: LLMContext
|
||||
hass: HomeAssistant, api_id: str | list[str], llm_context: LLMContext
|
||||
) -> APIInstance:
|
||||
"""Get an API."""
|
||||
"""Get an API.
|
||||
|
||||
This returns a single APIInstance for one or more API ids, merging into
|
||||
a single instance of necessary.
|
||||
"""
|
||||
apis = _async_get_apis(hass)
|
||||
|
||||
if api_id not in apis:
|
||||
raise HomeAssistantError(f"API {api_id} not found")
|
||||
if isinstance(api_id, str):
|
||||
api_id = [api_id]
|
||||
|
||||
return await apis[api_id].async_get_api_instance(llm_context)
|
||||
for key in api_id:
|
||||
if key not in apis:
|
||||
raise HomeAssistantError(f"API {key} not found")
|
||||
|
||||
api: API
|
||||
if len(api_id) == 1:
|
||||
api = apis[api_id[0]]
|
||||
else:
|
||||
api = MergedAPI([apis[key] for key in api_id])
|
||||
|
||||
return await api.async_get_api_instance(llm_context)
|
||||
|
||||
|
||||
@callback
|
||||
@ -286,6 +300,102 @@ class IntentTool(Tool):
|
||||
return response
|
||||
|
||||
|
||||
class NamespacedTool(Tool):
|
||||
"""A tool that wraps another tool, prepending a namespace.
|
||||
|
||||
This is used to support tools from multiple API. This tool dispatches
|
||||
the original tool with the original non-namespaced name.
|
||||
"""
|
||||
|
||||
def __init__(self, namespace: str, tool: Tool) -> None:
|
||||
"""Init the class."""
|
||||
self.namespace = namespace
|
||||
self.name = f"{namespace}.{tool.name}"
|
||||
self.description = tool.description
|
||||
self.parameters = tool.parameters
|
||||
self.tool = tool
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||
) -> JsonObjectType:
|
||||
"""Handle the intent."""
|
||||
return await self.tool.async_call(
|
||||
hass,
|
||||
ToolInput(
|
||||
tool_name=self.tool.name,
|
||||
tool_args=tool_input.tool_args,
|
||||
id=tool_input.id,
|
||||
),
|
||||
llm_context,
|
||||
)
|
||||
|
||||
|
||||
class MergedAPI(API):
|
||||
"""An API that represents a merged view of multiple APIs."""
|
||||
|
||||
def __init__(self, llm_apis: list[API]) -> None:
|
||||
"""Init the class."""
|
||||
if not llm_apis:
|
||||
raise ValueError("No APIs provided")
|
||||
hass = llm_apis[0].hass
|
||||
api_ids = [unicode_slug.slugify(api.id) for api in llm_apis]
|
||||
if len(set(api_ids)) != len(api_ids):
|
||||
raise ValueError("API IDs must be unique")
|
||||
super().__init__(
|
||||
hass=hass,
|
||||
id="|".join(unicode_slug.slugify(api.id) for api in llm_apis),
|
||||
name="Merged LLM API",
|
||||
)
|
||||
self.llm_apis = llm_apis
|
||||
|
||||
async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
|
||||
"""Return the instance of the API."""
|
||||
# These usually don't do I/O and execute right away
|
||||
llm_apis = [
|
||||
await llm_api.async_get_api_instance(llm_context)
|
||||
for llm_api in self.llm_apis
|
||||
]
|
||||
prompt_parts = []
|
||||
tools: list[Tool] = []
|
||||
for api_instance in llm_apis:
|
||||
namespace = unicode_slug.slugify(api_instance.api.name)
|
||||
prompt_parts.append(
|
||||
f'Follow these instructions for tools from "{namespace}":\n'
|
||||
)
|
||||
prompt_parts.append(api_instance.api_prompt)
|
||||
prompt_parts.append("\n\n")
|
||||
tools.extend(
|
||||
[NamespacedTool(namespace, tool) for tool in api_instance.tools]
|
||||
)
|
||||
|
||||
return APIInstance(
|
||||
api=self,
|
||||
api_prompt="".join(prompt_parts),
|
||||
llm_context=llm_context,
|
||||
tools=tools,
|
||||
custom_serializer=self._custom_serializer(llm_apis),
|
||||
)
|
||||
|
||||
def _custom_serializer(
|
||||
self, llm_apis: list[APIInstance]
|
||||
) -> Callable[[Any], Any] | None:
|
||||
serializers = [
|
||||
api_instance.custom_serializer
|
||||
for api_instance in llm_apis
|
||||
if api_instance.custom_serializer is not None
|
||||
]
|
||||
if not serializers:
|
||||
return None
|
||||
|
||||
def merged(x: Any) -> Any:
|
||||
for serializer in serializers:
|
||||
if (result := serializer(x)) is not None:
|
||||
return result
|
||||
return x
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
class AssistAPI(API):
|
||||
"""API exposing Assist API to LLMs."""
|
||||
|
||||
|
@ -139,6 +139,48 @@ async def test_unknown_llm_api(
|
||||
assert exc_info.value.as_conversation_result().as_dict() == snapshot
|
||||
|
||||
|
||||
async def test_multiple_llm_apis(
|
||||
hass: HomeAssistant,
|
||||
mock_conversation_input: ConversationInput,
|
||||
) -> None:
|
||||
"""Test when we reference an LLM API."""
|
||||
|
||||
class MyTool(llm.Tool):
|
||||
"""Test tool."""
|
||||
|
||||
name = "test_tool"
|
||||
description = "Test function"
|
||||
parameters = vol.Schema(
|
||||
{vol.Optional("param1", description="Test parameters"): str}
|
||||
)
|
||||
|
||||
class MyAPI(llm.API):
|
||||
"""Test API."""
|
||||
|
||||
async def async_get_api_instance(
|
||||
self, llm_context: llm.LLMContext
|
||||
) -> llm.APIInstance:
|
||||
"""Return a list of tools."""
|
||||
return llm.APIInstance(self, "My API Prompt", llm_context, [MyTool()])
|
||||
|
||||
api = MyAPI(hass=hass, id="my-api", name="Test")
|
||||
llm.async_register_api(hass, api)
|
||||
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass) as session,
|
||||
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||
):
|
||||
await chat_log.async_update_llm_data(
|
||||
conversing_domain="test",
|
||||
user_input=mock_conversation_input,
|
||||
user_llm_hass_api=["assist", "my-api"],
|
||||
user_llm_prompt=None,
|
||||
)
|
||||
|
||||
assert chat_log.llm_api
|
||||
assert chat_log.llm_api.api.id == "assist|my-api"
|
||||
|
||||
|
||||
async def test_template_error(
|
||||
hass: HomeAssistant,
|
||||
mock_conversation_input: ConversationInput,
|
||||
|
@ -25,6 +25,7 @@ from homeassistant.helpers import (
|
||||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
from tests.common import MockConfigEntry, async_mock_service
|
||||
|
||||
@ -45,9 +46,12 @@ def llm_context() -> llm.LLMContext:
|
||||
class MyAPI(llm.API):
|
||||
"""Test API."""
|
||||
|
||||
prompt: str = ""
|
||||
tools: list[llm.Tool] = []
|
||||
|
||||
async def async_get_api_instance(self, _: llm.ToolInput) -> llm.APIInstance:
|
||||
"""Return a list of tools."""
|
||||
return llm.APIInstance(self, "", [], llm_context)
|
||||
return llm.APIInstance(self, self.prompt, llm_context, self.tools)
|
||||
|
||||
|
||||
async def test_get_api_no_existing(
|
||||
@ -1326,3 +1330,57 @@ async def test_no_tools_exposed(hass: HomeAssistant) -> None:
|
||||
)
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
assert api.tools == []
|
||||
|
||||
|
||||
async def test_merged_api(hass: HomeAssistant, llm_context: llm.LLMContext) -> None:
|
||||
"""Test an API instance that merges multiple llm apis."""
|
||||
|
||||
class MyTool(llm.Tool):
|
||||
def __init__(self, name: str, description: str) -> None:
|
||||
self.name = name
|
||||
self.description = description
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: llm.ToolInput, _: llm.LLMContext
|
||||
) -> JsonObjectType:
|
||||
return {"result": {tool_input.tool_name: tool_input.tool_args}}
|
||||
|
||||
api1 = MyAPI(hass=hass, id="api-1", name="API 1")
|
||||
api1.prompt = "This is prompt 1"
|
||||
api1.tools = [MyTool(name="Tool_1", description="Description 1")]
|
||||
llm.async_register_api(hass, api1)
|
||||
|
||||
api2 = MyAPI(hass=hass, id="api-2", name="API 2")
|
||||
api2.prompt = "This is prompt 2"
|
||||
api2.tools = [MyTool(name="Tool_2", description="Description 2")]
|
||||
llm.async_register_api(hass, api2)
|
||||
|
||||
instance = await llm.async_get_api(hass, ["api-1", "api-2"], llm_context)
|
||||
assert instance.api.id == "api-1|api-2"
|
||||
|
||||
assert (
|
||||
instance.api_prompt
|
||||
== """Follow these instructions for tools from "api-1":
|
||||
This is prompt 1
|
||||
|
||||
Follow these instructions for tools from "api-2":
|
||||
This is prompt 2
|
||||
|
||||
"""
|
||||
)
|
||||
assert [(tool.name, tool.description) for tool in instance.tools] == [
|
||||
("api-1.Tool_1", "Description 1"),
|
||||
("api-2.Tool_2", "Description 2"),
|
||||
]
|
||||
|
||||
# The test tool returns back the provided arguments so we can verify
|
||||
# the original tool is invoked with the correct tool name and args.
|
||||
result = await instance.async_call_tool(
|
||||
llm.ToolInput(tool_name="api-1.Tool_1", tool_args={"arg1": "value1"})
|
||||
)
|
||||
assert result == {"result": {"Tool_1": {"arg1": "value1"}}}
|
||||
|
||||
result = await instance.async_call_tool(
|
||||
llm.ToolInput(tool_name="api-2.Tool_2", tool_args={"arg2": "value2"})
|
||||
)
|
||||
assert result == {"result": {"Tool_2": {"arg2": "value2"}}}
|
||||
|
Loading…
x
Reference in New Issue
Block a user