Split Ollama entity (#147769)

This commit is contained in:
Paulus Schoutsen 2025-06-30 21:31:54 +02:00 committed by GitHub
parent 70856bd92a
commit bf74ba990a
3 changed files with 268 additions and 245 deletions
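
In outline, the shared LLM plumbing (tool formatting, argument repair, stream transformation, and history trimming) moves out of conversation.py into a new OllamaBaseLLMEntity in entity.py, and the conversation entity mixes that base class in. A rough sketch of the resulting layout, simplified from the diff below:

```python
# Sketch of the structure after this commit (simplified; see the diff below).
from homeassistant.components import conversation
from homeassistant.helpers.entity import Entity


class OllamaBaseLLMEntity(Entity):
    """Shared device info plus _async_handle_chat_log / _trim_history."""


class OllamaConversationEntity(
    conversation.ConversationEntity,
    conversation.AbstractConversationAgent,
    OllamaBaseLLMEntity,
):
    """Conversation agent; LLM handling is inherited from the base entity."""
```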

homeassistant/components/ollama/conversation.py

@@ -2,41 +2,18 @@
from __future__ import annotations
from collections.abc import AsyncGenerator, AsyncIterator, Callable
import json
import logging
from typing import Any, Literal
import ollama
from voluptuous_openapi import convert
from typing import Literal
from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, intent, llm
from homeassistant.helpers import intent
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import OllamaConfigEntry
from .const import (
CONF_KEEP_ALIVE,
CONF_MAX_HISTORY,
CONF_MODEL,
CONF_NUM_CTX,
CONF_PROMPT,
CONF_THINK,
DEFAULT_KEEP_ALIVE,
DEFAULT_MAX_HISTORY,
DEFAULT_NUM_CTX,
DOMAIN,
)
from .models import MessageHistory, MessageRole
# Max number of back-and-forth exchanges with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
_LOGGER = logging.getLogger(__name__)
from .const import CONF_PROMPT, DOMAIN
from .entity import OllamaBaseLLMEntity
async def async_setup_entry(
@@ -55,129 +32,10 @@ async def async_setup_entry(
)
def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
"""Format tool specification."""
tool_spec = {
"name": tool.name,
"parameters": convert(tool.parameters, custom_serializer=custom_serializer),
}
if tool.description:
tool_spec["description"] = tool.description
return {"type": "function", "function": tool_spec}
def _fix_invalid_arguments(value: Any) -> Any:
"""Attempt to repair incorrectly formatted json function arguments.
Small models (for example llama3.1 8B) may produce invalid argument values
which we attempt to repair here.
"""
if not isinstance(value, str):
return value
if (value.startswith("[") and value.endswith("]")) or (
value.startswith("{") and value.endswith("}")
):
try:
return json.loads(value)
except json.decoder.JSONDecodeError:
pass
return value
def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
"""Rewrite ollama tool arguments.
This function improves tool use quality by fixing common mistakes made by
small local tool use models. This will repair invalid json arguments and
omit unnecessary arguments with empty values that will fail intent parsing.
"""
return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}
def _convert_content(
chat_content: (
conversation.Content
| conversation.ToolResultContent
| conversation.AssistantContent
),
) -> ollama.Message:
"""Create tool response content."""
if isinstance(chat_content, conversation.ToolResultContent):
return ollama.Message(
role=MessageRole.TOOL.value,
content=json.dumps(chat_content.tool_result),
)
if isinstance(chat_content, conversation.AssistantContent):
return ollama.Message(
role=MessageRole.ASSISTANT.value,
content=chat_content.content,
tool_calls=[
ollama.Message.ToolCall(
function=ollama.Message.ToolCall.Function(
name=tool_call.tool_name,
arguments=tool_call.tool_args,
)
)
for tool_call in chat_content.tool_calls or ()
],
)
if isinstance(chat_content, conversation.UserContent):
return ollama.Message(
role=MessageRole.USER.value,
content=chat_content.content,
)
if isinstance(chat_content, conversation.SystemContent):
return ollama.Message(
role=MessageRole.SYSTEM.value,
content=chat_content.content,
)
raise TypeError(f"Unexpected content type: {type(chat_content)}")
async def _transform_stream(
result: AsyncIterator[ollama.ChatResponse],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
"""Transform the response stream into HA format.
An Ollama streaming response may come in chunks like this:
response: message=Message(role="assistant", content="Paris")
response: message=Message(role="assistant", content=".")
response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
response: message=Message(role="assistant", tool_calls=[...])
response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
This generator conforms to the chatlog delta stream expectations in that it
yields deltas, then the role only once the response is done.
"""
new_msg = True
async for response in result:
_LOGGER.debug("Received response: %s", response)
response_message = response["message"]
chunk: conversation.AssistantContentDeltaDict = {}
if new_msg:
new_msg = False
chunk["role"] = "assistant"
if (tool_calls := response_message.get("tool_calls")) is not None:
chunk["tool_calls"] = [
llm.ToolInput(
tool_name=tool_call["function"]["name"],
tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
)
for tool_call in tool_calls
]
if (content := response_message.get("content")) is not None:
chunk["content"] = content
if response_message.get("done"):
new_msg = True
yield chunk
class OllamaConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
conversation.ConversationEntity,
conversation.AbstractConversationAgent,
OllamaBaseLLMEntity,
):
"""Ollama conversation agent."""
@@ -185,17 +43,7 @@ class OllamaConversationEntity(
def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None:
"""Initialize the agent."""
self.entry = entry
self.subentry = subentry
self._attr_name = subentry.title
self._attr_unique_id = subentry.subentry_id
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, subentry.subentry_id)},
name=subentry.title,
manufacturer="Ollama",
model=entry.data[CONF_MODEL],
entry_type=dr.DeviceEntryType.SERVICE,
)
super().__init__(entry, subentry)
if self.subentry.data.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
@@ -255,89 +103,6 @@ class OllamaConversationEntity(
continue_conversation=chat_log.continue_conversation,
)
async def _async_handle_chat_log(
self,
chat_log: conversation.ChatLog,
) -> None:
"""Generate an answer for the chat log."""
settings = {**self.entry.data, **self.subentry.data}
client = self.entry.runtime_data
model = settings[CONF_MODEL]
tools: list[dict[str, Any]] | None = None
if chat_log.llm_api:
tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
message_history: MessageHistory = MessageHistory(
[_convert_content(content) for content in chat_log.content]
)
max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
self._trim_history(message_history, max_messages)
# Get response
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
response_generator = await client.chat(
model=model,
# Make a copy of the messages because we mutate the list later
messages=list(message_history.messages),
tools=tools,
stream=True,
# keep_alive requires specifying unit. In this case, seconds
keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
think=settings.get(CONF_THINK),
)
except (ollama.RequestError, ollama.ResponseError) as err:
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
raise HomeAssistantError(
f"Sorry, I had a problem talking to the Ollama server: {err}"
) from err
message_history.messages.extend(
[
_convert_content(content)
async for content in chat_log.async_add_delta_content_stream(
self.entity_id, _transform_stream(response_generator)
)
]
)
if not chat_log.unresponded_tool_results:
break
def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
"""Trims excess messages from a single history.
This sets the max history to allow a configurable size history may take
up in the context window.
Note that some messages in the history may not be from ollama only, and
may come from other anents, so the assumptions here may not strictly hold,
but generally should be effective.
"""
if max_messages < 1:
# Keep all messages
return
# Ignore the in progress user message
num_previous_rounds = message_history.num_user_messages - 1
if num_previous_rounds >= max_messages:
# Trim history but keep system prompt (first message).
# Every other message should be an assistant message, so keep 2x
# message objects. Also keep the last in progress user message
num_keep = 2 * max_messages + 1
drop_index = len(message_history.messages) - num_keep
message_history.messages = [
message_history.messages[0],
*message_history.messages[drop_index:],
]
async def _async_entry_update_listener(
self, hass: HomeAssistant, entry: ConfigEntry
) -> None:

homeassistant/components/ollama/entity.py

@@ -0,0 +1,258 @@
"""Base entity for the Ollama integration."""
from __future__ import annotations
from collections.abc import AsyncGenerator, AsyncIterator, Callable
import json
import logging
from typing import Any
import ollama
from voluptuous_openapi import convert
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigSubentry
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, llm
from homeassistant.helpers.entity import Entity
from . import OllamaConfigEntry
from .const import (
CONF_KEEP_ALIVE,
CONF_MAX_HISTORY,
CONF_MODEL,
CONF_NUM_CTX,
CONF_THINK,
DEFAULT_KEEP_ALIVE,
DEFAULT_MAX_HISTORY,
DEFAULT_NUM_CTX,
DOMAIN,
)
from .models import MessageHistory, MessageRole
# Max number of back-and-forth exchanges with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
_LOGGER = logging.getLogger(__name__)
def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
"""Format tool specification."""
tool_spec = {
"name": tool.name,
"parameters": convert(tool.parameters, custom_serializer=custom_serializer),
}
if tool.description:
tool_spec["description"] = tool.description
return {"type": "function", "function": tool_spec}
def _fix_invalid_arguments(value: Any) -> Any:
"""Attempt to repair incorrectly formatted json function arguments.
Small models (for example llama3.1 8B) may produce invalid argument values
which we attempt to repair here.
"""
if not isinstance(value, str):
return value
if (value.startswith("[") and value.endswith("]")) or (
value.startswith("{") and value.endswith("}")
):
try:
return json.loads(value)
except json.decoder.JSONDecodeError:
pass
return value
def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
"""Rewrite ollama tool arguments.
This function improves tool use quality by fixing common mistakes made by
small local tool use models. This will repair invalid json arguments and
omit unnecessary arguments with empty values that will fail intent parsing.
"""
return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}
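
A small illustration of the repair behavior described above, assuming the module layout introduced by this commit; the argument values are invented:

```python
# Illustration of the expected repair behavior (example values are invented).
from homeassistant.components.ollama.entity import (  # assumes this commit's layout
    _fix_invalid_arguments,
    _parse_tool_args,
)

# A JSON object serialized into a string is decoded back into a dict.
assert _fix_invalid_arguments('{"brightness": 50}') == {"brightness": 50}
# Plain strings are passed through unchanged.
assert _fix_invalid_arguments("kitchen") == "kitchen"
# Arguments with empty values are dropped so they do not break intent parsing.
assert _parse_tool_args({"name": "kitchen", "area": ""}) == {"name": "kitchen"}
```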
def _convert_content(
chat_content: (
conversation.Content
| conversation.ToolResultContent
| conversation.AssistantContent
),
) -> ollama.Message:
"""Create tool response content."""
if isinstance(chat_content, conversation.ToolResultContent):
return ollama.Message(
role=MessageRole.TOOL.value,
content=json.dumps(chat_content.tool_result),
)
if isinstance(chat_content, conversation.AssistantContent):
return ollama.Message(
role=MessageRole.ASSISTANT.value,
content=chat_content.content,
tool_calls=[
ollama.Message.ToolCall(
function=ollama.Message.ToolCall.Function(
name=tool_call.tool_name,
arguments=tool_call.tool_args,
)
)
for tool_call in chat_content.tool_calls or ()
],
)
if isinstance(chat_content, conversation.UserContent):
return ollama.Message(
role=MessageRole.USER.value,
content=chat_content.content,
)
if isinstance(chat_content, conversation.SystemContent):
return ollama.Message(
role=MessageRole.SYSTEM.value,
content=chat_content.content,
)
raise TypeError(f"Unexpected content type: {type(chat_content)}")
async def _transform_stream(
result: AsyncIterator[ollama.ChatResponse],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
"""Transform the response stream into HA format.
An Ollama streaming response may come in chunks like this:
response: message=Message(role="assistant", content="Paris")
response: message=Message(role="assistant", content=".")
response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
response: message=Message(role="assistant", tool_calls=[...])
response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
This generator conforms to the chatlog delta stream expectations in that it
yields deltas, then the role only once the response is done.
"""
new_msg = True
async for response in result:
_LOGGER.debug("Received response: %s", response)
response_message = response["message"]
chunk: conversation.AssistantContentDeltaDict = {}
if new_msg:
new_msg = False
chunk["role"] = "assistant"
if (tool_calls := response_message.get("tool_calls")) is not None:
chunk["tool_calls"] = [
llm.ToolInput(
tool_name=tool_call["function"]["name"],
tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
)
for tool_call in tool_calls
]
if (content := response_message.get("content")) is not None:
chunk["content"] = content
if response_message.get("done"):
new_msg = True
yield chunk
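
As a rough illustration of the contract described in the docstring, the streamed "Paris." example would be expected to yield deltas along these lines (shapes inferred from the code above, not captured output):

```python
# Assumed delta sequence for the "Paris." stream shown in the docstring above;
# the role appears only on the first chunk of each assistant message.
expected_deltas = [
    {"role": "assistant", "content": "Paris"},
    {"content": "."},
    {"content": ""},  # final chunk; done resets new_msg for the next message
]
```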
class OllamaBaseLLMEntity(Entity):
"""Ollama base LLM entity."""
def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None:
"""Initialize the entity."""
self.entry = entry
self.subentry = subentry
self._attr_name = subentry.title
self._attr_unique_id = subentry.subentry_id
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, subentry.subentry_id)},
name=subentry.title,
manufacturer="Ollama",
model=entry.data[CONF_MODEL],
entry_type=dr.DeviceEntryType.SERVICE,
)
async def _async_handle_chat_log(
self,
chat_log: conversation.ChatLog,
) -> None:
"""Generate an answer for the chat log."""
settings = {**self.entry.data, **self.subentry.data}
client = self.entry.runtime_data
model = settings[CONF_MODEL]
tools: list[dict[str, Any]] | None = None
if chat_log.llm_api:
tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
message_history: MessageHistory = MessageHistory(
[_convert_content(content) for content in chat_log.content]
)
max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
self._trim_history(message_history, max_messages)
# Get response
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
response_generator = await client.chat(
model=model,
# Make a copy of the messages because we mutate the list later
messages=list(message_history.messages),
tools=tools,
stream=True,
# keep_alive requires specifying unit. In this case, seconds
keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
think=settings.get(CONF_THINK),
)
except (ollama.RequestError, ollama.ResponseError) as err:
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
raise HomeAssistantError(
f"Sorry, I had a problem talking to the Ollama server: {err}"
) from err
message_history.messages.extend(
[
_convert_content(content)
async for content in chat_log.async_add_delta_content_stream(
self.entity_id, _transform_stream(response_generator)
)
]
)
if not chat_log.unresponded_tool_results:
break
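
For comparison, a standalone sketch of the same ollama client call outside Home Assistant; the host, model, and option values here are illustrative assumptions rather than the integration's defaults:

```python
# Standalone sketch of the chat call made above; host, model, and option
# values are illustrative assumptions, not the integration's defaults.
import asyncio

import ollama


async def main() -> None:
    client = ollama.AsyncClient(host="http://localhost:11434")
    stream = await client.chat(
        model="llama3.1",
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        stream=True,
        keep_alive="30s",  # keep_alive requires an explicit unit (seconds here)
        options={"num_ctx": 8192},
    )
    async for part in stream:
        print(part["message"]["content"], end="", flush=True)


asyncio.run(main())
```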
def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
"""Trims excess messages from a single history.
This sets the max history to allow a configurable size history may take
up in the context window.
Note that some messages in the history may not be from ollama only, and
may come from other anents, so the assumptions here may not strictly hold,
but generally should be effective.
"""
if max_messages < 1:
# Keep all messages
return
# Ignore the in progress user message
num_previous_rounds = message_history.num_user_messages - 1
if num_previous_rounds >= max_messages:
# Trim history but keep system prompt (first message).
# Every other message should be an assistant message, so keep 2x
# message objects. Also keep the last in progress user message
num_keep = 2 * max_messages + 1
drop_index = len(message_history.messages) - num_keep
message_history.messages = [
message_history.messages[0],
*message_history.messages[drop_index:],
]
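
A worked example of the trimming arithmetic, assuming max_messages=2 and a history of a system prompt, three completed rounds, and one in-progress user message:

```python
# Worked example of the trimming above (message contents are placeholders).
# History: system prompt + 3 completed rounds + the in-progress user message.
messages = ["sys", "u1", "a1", "u2", "a2", "u3", "a3", "u4"]
max_messages = 2

num_previous_rounds = 4 - 1            # 4 user messages minus the in-progress one
num_keep = 2 * max_messages + 1        # 2 full rounds (4 msgs) + in-progress user
drop_index = len(messages) - num_keep  # 8 - 5 = 3
trimmed = [messages[0], *messages[drop_index:]]
assert trimmed == ["sys", "u2", "a2", "u3", "a3", "u4"]
```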

tests/components/ollama/test_conversation.py

@@ -206,7 +206,7 @@ async def test_template_variables(
),
],
)
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
@patch("homeassistant.components.ollama.entity.llm.AssistAPI._async_get_tools")
async def test_function_call(
mock_get_tools,
hass: HomeAssistant,
@@ -293,7 +293,7 @@ async def test_function_call(
)
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
@patch("homeassistant.components.ollama.entity.llm.AssistAPI._async_get_tools")
async def test_function_exception(
mock_get_tools,
hass: HomeAssistant,
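
Other tests that patch helpers through the old module path would need the same update; a minimal sketch of the new patch target (test body omitted):

```python
# Sketch: patch the Assist API tools at their new import location
# (homeassistant.components.ollama.entity); the test body is omitted.
from unittest.mock import patch

with patch(
    "homeassistant.components.ollama.entity.llm.AssistAPI._async_get_tools"
) as mock_get_tools:
    mock_get_tools.return_value = []
    # ... run the conversation agent and assert on mock_get_tools here ...
```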