mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Update ollama to allow selecting mutiple LLM APIs (#142445)
* Update ollama to allow selecting mutiple LLM APIs * Update homeassistant/helpers/llm.py * Avoid gather since these don't do I/O --------- Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
8b88272bc0
commit
d91528648f
@ -387,7 +387,7 @@ class ChatLog:
|
|||||||
self,
|
self,
|
||||||
conversing_domain: str,
|
conversing_domain: str,
|
||||||
user_input: ConversationInput,
|
user_input: ConversationInput,
|
||||||
user_llm_hass_api: str | None = None,
|
user_llm_hass_api: str | list[str] | None = None,
|
||||||
user_llm_prompt: str | None = None,
|
user_llm_prompt: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set the LLM system prompt."""
|
"""Set the LLM system prompt."""
|
||||||
|
@ -215,8 +215,6 @@ class OllamaOptionsFlow(OptionsFlow):
|
|||||||
) -> ConfigFlowResult:
|
) -> ConfigFlowResult:
|
||||||
"""Manage the options."""
|
"""Manage the options."""
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
if user_input[CONF_LLM_HASS_API] == "none":
|
|
||||||
user_input.pop(CONF_LLM_HASS_API)
|
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=_get_title(self.model), data=user_input
|
title=_get_title(self.model), data=user_input
|
||||||
)
|
)
|
||||||
@ -234,18 +232,12 @@ def ollama_config_option_schema(
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
"""Ollama options schema."""
|
"""Ollama options schema."""
|
||||||
hass_apis: list[SelectOptionDict] = [
|
hass_apis: list[SelectOptionDict] = [
|
||||||
SelectOptionDict(
|
|
||||||
label="No control",
|
|
||||||
value="none",
|
|
||||||
)
|
|
||||||
]
|
|
||||||
hass_apis.extend(
|
|
||||||
SelectOptionDict(
|
SelectOptionDict(
|
||||||
label=api.name,
|
label=api.name,
|
||||||
value=api.id,
|
value=api.id,
|
||||||
)
|
)
|
||||||
for api in llm.async_get_apis(hass)
|
for api in llm.async_get_apis(hass)
|
||||||
)
|
]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
@ -259,8 +251,7 @@ def ollama_config_option_schema(
|
|||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_LLM_HASS_API,
|
CONF_LLM_HASS_API,
|
||||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||||
default="none",
|
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
|
||||||
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
|
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_NUM_CTX,
|
CONF_NUM_CTX,
|
||||||
description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
|
description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
|
||||||
|
@ -110,15 +110,29 @@ def async_register_api(hass: HomeAssistant, api: API) -> Callable[[], None]:
|
|||||||
|
|
||||||
|
|
||||||
async def async_get_api(
|
async def async_get_api(
|
||||||
hass: HomeAssistant, api_id: str, llm_context: LLMContext
|
hass: HomeAssistant, api_id: str | list[str], llm_context: LLMContext
|
||||||
) -> APIInstance:
|
) -> APIInstance:
|
||||||
"""Get an API."""
|
"""Get an API.
|
||||||
|
|
||||||
|
This returns a single APIInstance for one or more API ids, merging into
|
||||||
|
a single instance of necessary.
|
||||||
|
"""
|
||||||
apis = _async_get_apis(hass)
|
apis = _async_get_apis(hass)
|
||||||
|
|
||||||
if api_id not in apis:
|
if isinstance(api_id, str):
|
||||||
raise HomeAssistantError(f"API {api_id} not found")
|
api_id = [api_id]
|
||||||
|
|
||||||
return await apis[api_id].async_get_api_instance(llm_context)
|
for key in api_id:
|
||||||
|
if key not in apis:
|
||||||
|
raise HomeAssistantError(f"API {key} not found")
|
||||||
|
|
||||||
|
api: API
|
||||||
|
if len(api_id) == 1:
|
||||||
|
api = apis[api_id[0]]
|
||||||
|
else:
|
||||||
|
api = MergedAPI([apis[key] for key in api_id])
|
||||||
|
|
||||||
|
return await api.async_get_api_instance(llm_context)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@ -286,6 +300,102 @@ class IntentTool(Tool):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class NamespacedTool(Tool):
|
||||||
|
"""A tool that wraps another tool, prepending a namespace.
|
||||||
|
|
||||||
|
This is used to support tools from multiple API. This tool dispatches
|
||||||
|
the original tool with the original non-namespaced name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, namespace: str, tool: Tool) -> None:
|
||||||
|
"""Init the class."""
|
||||||
|
self.namespace = namespace
|
||||||
|
self.name = f"{namespace}.{tool.name}"
|
||||||
|
self.description = tool.description
|
||||||
|
self.parameters = tool.parameters
|
||||||
|
self.tool = tool
|
||||||
|
|
||||||
|
async def async_call(
|
||||||
|
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||||
|
) -> JsonObjectType:
|
||||||
|
"""Handle the intent."""
|
||||||
|
return await self.tool.async_call(
|
||||||
|
hass,
|
||||||
|
ToolInput(
|
||||||
|
tool_name=self.tool.name,
|
||||||
|
tool_args=tool_input.tool_args,
|
||||||
|
id=tool_input.id,
|
||||||
|
),
|
||||||
|
llm_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MergedAPI(API):
|
||||||
|
"""An API that represents a merged view of multiple APIs."""
|
||||||
|
|
||||||
|
def __init__(self, llm_apis: list[API]) -> None:
|
||||||
|
"""Init the class."""
|
||||||
|
if not llm_apis:
|
||||||
|
raise ValueError("No APIs provided")
|
||||||
|
hass = llm_apis[0].hass
|
||||||
|
api_ids = [unicode_slug.slugify(api.id) for api in llm_apis]
|
||||||
|
if len(set(api_ids)) != len(api_ids):
|
||||||
|
raise ValueError("API IDs must be unique")
|
||||||
|
super().__init__(
|
||||||
|
hass=hass,
|
||||||
|
id="|".join(unicode_slug.slugify(api.id) for api in llm_apis),
|
||||||
|
name="Merged LLM API",
|
||||||
|
)
|
||||||
|
self.llm_apis = llm_apis
|
||||||
|
|
||||||
|
async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
|
||||||
|
"""Return the instance of the API."""
|
||||||
|
# These usually don't do I/O and execute right away
|
||||||
|
llm_apis = [
|
||||||
|
await llm_api.async_get_api_instance(llm_context)
|
||||||
|
for llm_api in self.llm_apis
|
||||||
|
]
|
||||||
|
prompt_parts = []
|
||||||
|
tools: list[Tool] = []
|
||||||
|
for api_instance in llm_apis:
|
||||||
|
namespace = unicode_slug.slugify(api_instance.api.name)
|
||||||
|
prompt_parts.append(
|
||||||
|
f'Follow these instructions for tools from "{namespace}":\n'
|
||||||
|
)
|
||||||
|
prompt_parts.append(api_instance.api_prompt)
|
||||||
|
prompt_parts.append("\n\n")
|
||||||
|
tools.extend(
|
||||||
|
[NamespacedTool(namespace, tool) for tool in api_instance.tools]
|
||||||
|
)
|
||||||
|
|
||||||
|
return APIInstance(
|
||||||
|
api=self,
|
||||||
|
api_prompt="".join(prompt_parts),
|
||||||
|
llm_context=llm_context,
|
||||||
|
tools=tools,
|
||||||
|
custom_serializer=self._custom_serializer(llm_apis),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _custom_serializer(
|
||||||
|
self, llm_apis: list[APIInstance]
|
||||||
|
) -> Callable[[Any], Any] | None:
|
||||||
|
serializers = [
|
||||||
|
api_instance.custom_serializer
|
||||||
|
for api_instance in llm_apis
|
||||||
|
if api_instance.custom_serializer is not None
|
||||||
|
]
|
||||||
|
if not serializers:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def merged(x: Any) -> Any:
|
||||||
|
for serializer in serializers:
|
||||||
|
if (result := serializer(x)) is not None:
|
||||||
|
return result
|
||||||
|
return x
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
class AssistAPI(API):
|
class AssistAPI(API):
|
||||||
"""API exposing Assist API to LLMs."""
|
"""API exposing Assist API to LLMs."""
|
||||||
|
|
||||||
|
@ -139,6 +139,48 @@ async def test_unknown_llm_api(
|
|||||||
assert exc_info.value.as_conversation_result().as_dict() == snapshot
|
assert exc_info.value.as_conversation_result().as_dict() == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
async def test_multiple_llm_apis(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_conversation_input: ConversationInput,
|
||||||
|
) -> None:
|
||||||
|
"""Test when we reference an LLM API."""
|
||||||
|
|
||||||
|
class MyTool(llm.Tool):
|
||||||
|
"""Test tool."""
|
||||||
|
|
||||||
|
name = "test_tool"
|
||||||
|
description = "Test function"
|
||||||
|
parameters = vol.Schema(
|
||||||
|
{vol.Optional("param1", description="Test parameters"): str}
|
||||||
|
)
|
||||||
|
|
||||||
|
class MyAPI(llm.API):
|
||||||
|
"""Test API."""
|
||||||
|
|
||||||
|
async def async_get_api_instance(
|
||||||
|
self, llm_context: llm.LLMContext
|
||||||
|
) -> llm.APIInstance:
|
||||||
|
"""Return a list of tools."""
|
||||||
|
return llm.APIInstance(self, "My API Prompt", llm_context, [MyTool()])
|
||||||
|
|
||||||
|
api = MyAPI(hass=hass, id="my-api", name="Test")
|
||||||
|
llm.async_register_api(hass, api)
|
||||||
|
|
||||||
|
with (
|
||||||
|
chat_session.async_get_chat_session(hass) as session,
|
||||||
|
async_get_chat_log(hass, session, mock_conversation_input) as chat_log,
|
||||||
|
):
|
||||||
|
await chat_log.async_update_llm_data(
|
||||||
|
conversing_domain="test",
|
||||||
|
user_input=mock_conversation_input,
|
||||||
|
user_llm_hass_api=["assist", "my-api"],
|
||||||
|
user_llm_prompt=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chat_log.llm_api
|
||||||
|
assert chat_log.llm_api.api.id == "assist|my-api"
|
||||||
|
|
||||||
|
|
||||||
async def test_template_error(
|
async def test_template_error(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_conversation_input: ConversationInput,
|
mock_conversation_input: ConversationInput,
|
||||||
|
@ -25,6 +25,7 @@ from homeassistant.helpers import (
|
|||||||
)
|
)
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
from homeassistant.util.json import JsonObjectType
|
||||||
|
|
||||||
from tests.common import MockConfigEntry, async_mock_service
|
from tests.common import MockConfigEntry, async_mock_service
|
||||||
|
|
||||||
@ -45,9 +46,12 @@ def llm_context() -> llm.LLMContext:
|
|||||||
class MyAPI(llm.API):
|
class MyAPI(llm.API):
|
||||||
"""Test API."""
|
"""Test API."""
|
||||||
|
|
||||||
|
prompt: str = ""
|
||||||
|
tools: list[llm.Tool] = []
|
||||||
|
|
||||||
async def async_get_api_instance(self, _: llm.ToolInput) -> llm.APIInstance:
|
async def async_get_api_instance(self, _: llm.ToolInput) -> llm.APIInstance:
|
||||||
"""Return a list of tools."""
|
"""Return a list of tools."""
|
||||||
return llm.APIInstance(self, "", [], llm_context)
|
return llm.APIInstance(self, self.prompt, llm_context, self.tools)
|
||||||
|
|
||||||
|
|
||||||
async def test_get_api_no_existing(
|
async def test_get_api_no_existing(
|
||||||
@ -1326,3 +1330,57 @@ async def test_no_tools_exposed(hass: HomeAssistant) -> None:
|
|||||||
)
|
)
|
||||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||||
assert api.tools == []
|
assert api.tools == []
|
||||||
|
|
||||||
|
|
||||||
|
async def test_merged_api(hass: HomeAssistant, llm_context: llm.LLMContext) -> None:
|
||||||
|
"""Test an API instance that merges multiple llm apis."""
|
||||||
|
|
||||||
|
class MyTool(llm.Tool):
|
||||||
|
def __init__(self, name: str, description: str) -> None:
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
|
||||||
|
async def async_call(
|
||||||
|
self, hass: HomeAssistant, tool_input: llm.ToolInput, _: llm.LLMContext
|
||||||
|
) -> JsonObjectType:
|
||||||
|
return {"result": {tool_input.tool_name: tool_input.tool_args}}
|
||||||
|
|
||||||
|
api1 = MyAPI(hass=hass, id="api-1", name="API 1")
|
||||||
|
api1.prompt = "This is prompt 1"
|
||||||
|
api1.tools = [MyTool(name="Tool_1", description="Description 1")]
|
||||||
|
llm.async_register_api(hass, api1)
|
||||||
|
|
||||||
|
api2 = MyAPI(hass=hass, id="api-2", name="API 2")
|
||||||
|
api2.prompt = "This is prompt 2"
|
||||||
|
api2.tools = [MyTool(name="Tool_2", description="Description 2")]
|
||||||
|
llm.async_register_api(hass, api2)
|
||||||
|
|
||||||
|
instance = await llm.async_get_api(hass, ["api-1", "api-2"], llm_context)
|
||||||
|
assert instance.api.id == "api-1|api-2"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
instance.api_prompt
|
||||||
|
== """Follow these instructions for tools from "api-1":
|
||||||
|
This is prompt 1
|
||||||
|
|
||||||
|
Follow these instructions for tools from "api-2":
|
||||||
|
This is prompt 2
|
||||||
|
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
assert [(tool.name, tool.description) for tool in instance.tools] == [
|
||||||
|
("api-1.Tool_1", "Description 1"),
|
||||||
|
("api-2.Tool_2", "Description 2"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# The test tool returns back the provided arguments so we can verify
|
||||||
|
# the original tool is invoked with the correct tool name and args.
|
||||||
|
result = await instance.async_call_tool(
|
||||||
|
llm.ToolInput(tool_name="api-1.Tool_1", tool_args={"arg1": "value1"})
|
||||||
|
)
|
||||||
|
assert result == {"result": {"Tool_1": {"arg1": "value1"}}}
|
||||||
|
|
||||||
|
result = await instance.async_call_tool(
|
||||||
|
llm.ToolInput(tool_name="api-2.Tool_2", tool_args={"arg2": "value2"})
|
||||||
|
)
|
||||||
|
assert result == {"result": {"Tool_2": {"arg2": "value2"}}}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user