mirror of
https://github.com/home-assistant/core.git
synced 2025-11-09 02:49:40 +00:00
Update ollama to allow selecting mutiple LLM APIs (#142445)
* Update ollama to allow selecting mutiple LLM APIs * Update homeassistant/helpers/llm.py * Avoid gather since these don't do I/O --------- Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
@@ -110,15 +110,29 @@ def async_register_api(hass: HomeAssistant, api: API) -> Callable[[], None]:
|
||||
|
||||
|
||||
async def async_get_api(
|
||||
hass: HomeAssistant, api_id: str, llm_context: LLMContext
|
||||
hass: HomeAssistant, api_id: str | list[str], llm_context: LLMContext
|
||||
) -> APIInstance:
|
||||
"""Get an API."""
|
||||
"""Get an API.
|
||||
|
||||
This returns a single APIInstance for one or more API ids, merging into
|
||||
a single instance of necessary.
|
||||
"""
|
||||
apis = _async_get_apis(hass)
|
||||
|
||||
if api_id not in apis:
|
||||
raise HomeAssistantError(f"API {api_id} not found")
|
||||
if isinstance(api_id, str):
|
||||
api_id = [api_id]
|
||||
|
||||
return await apis[api_id].async_get_api_instance(llm_context)
|
||||
for key in api_id:
|
||||
if key not in apis:
|
||||
raise HomeAssistantError(f"API {key} not found")
|
||||
|
||||
api: API
|
||||
if len(api_id) == 1:
|
||||
api = apis[api_id[0]]
|
||||
else:
|
||||
api = MergedAPI([apis[key] for key in api_id])
|
||||
|
||||
return await api.async_get_api_instance(llm_context)
|
||||
|
||||
|
||||
@callback
|
||||
@@ -286,6 +300,102 @@ class IntentTool(Tool):
|
||||
return response
|
||||
|
||||
|
||||
class NamespacedTool(Tool):
|
||||
"""A tool that wraps another tool, prepending a namespace.
|
||||
|
||||
This is used to support tools from multiple API. This tool dispatches
|
||||
the original tool with the original non-namespaced name.
|
||||
"""
|
||||
|
||||
def __init__(self, namespace: str, tool: Tool) -> None:
|
||||
"""Init the class."""
|
||||
self.namespace = namespace
|
||||
self.name = f"{namespace}.{tool.name}"
|
||||
self.description = tool.description
|
||||
self.parameters = tool.parameters
|
||||
self.tool = tool
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||
) -> JsonObjectType:
|
||||
"""Handle the intent."""
|
||||
return await self.tool.async_call(
|
||||
hass,
|
||||
ToolInput(
|
||||
tool_name=self.tool.name,
|
||||
tool_args=tool_input.tool_args,
|
||||
id=tool_input.id,
|
||||
),
|
||||
llm_context,
|
||||
)
|
||||
|
||||
|
||||
class MergedAPI(API):
|
||||
"""An API that represents a merged view of multiple APIs."""
|
||||
|
||||
def __init__(self, llm_apis: list[API]) -> None:
|
||||
"""Init the class."""
|
||||
if not llm_apis:
|
||||
raise ValueError("No APIs provided")
|
||||
hass = llm_apis[0].hass
|
||||
api_ids = [unicode_slug.slugify(api.id) for api in llm_apis]
|
||||
if len(set(api_ids)) != len(api_ids):
|
||||
raise ValueError("API IDs must be unique")
|
||||
super().__init__(
|
||||
hass=hass,
|
||||
id="|".join(unicode_slug.slugify(api.id) for api in llm_apis),
|
||||
name="Merged LLM API",
|
||||
)
|
||||
self.llm_apis = llm_apis
|
||||
|
||||
async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
|
||||
"""Return the instance of the API."""
|
||||
# These usually don't do I/O and execute right away
|
||||
llm_apis = [
|
||||
await llm_api.async_get_api_instance(llm_context)
|
||||
for llm_api in self.llm_apis
|
||||
]
|
||||
prompt_parts = []
|
||||
tools: list[Tool] = []
|
||||
for api_instance in llm_apis:
|
||||
namespace = unicode_slug.slugify(api_instance.api.name)
|
||||
prompt_parts.append(
|
||||
f'Follow these instructions for tools from "{namespace}":\n'
|
||||
)
|
||||
prompt_parts.append(api_instance.api_prompt)
|
||||
prompt_parts.append("\n\n")
|
||||
tools.extend(
|
||||
[NamespacedTool(namespace, tool) for tool in api_instance.tools]
|
||||
)
|
||||
|
||||
return APIInstance(
|
||||
api=self,
|
||||
api_prompt="".join(prompt_parts),
|
||||
llm_context=llm_context,
|
||||
tools=tools,
|
||||
custom_serializer=self._custom_serializer(llm_apis),
|
||||
)
|
||||
|
||||
def _custom_serializer(
|
||||
self, llm_apis: list[APIInstance]
|
||||
) -> Callable[[Any], Any] | None:
|
||||
serializers = [
|
||||
api_instance.custom_serializer
|
||||
for api_instance in llm_apis
|
||||
if api_instance.custom_serializer is not None
|
||||
]
|
||||
if not serializers:
|
||||
return None
|
||||
|
||||
def merged(x: Any) -> Any:
|
||||
for serializer in serializers:
|
||||
if (result := serializer(x)) is not None:
|
||||
return result
|
||||
return x
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
class AssistAPI(API):
|
||||
"""API exposing Assist API to LLMs."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user