From bf74ba990aba4cf0964e13aea0cad233980880d7 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 30 Jun 2025 21:31:54 +0200 Subject: [PATCH] Split Ollama entity (#147769) --- .../components/ollama/conversation.py | 251 +---------------- homeassistant/components/ollama/entity.py | 258 ++++++++++++++++++ tests/components/ollama/test_conversation.py | 4 +- 3 files changed, 268 insertions(+), 245 deletions(-) create mode 100644 homeassistant/components/ollama/entity.py diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py index beedb61f942..ae4de7d48a1 100644 --- a/homeassistant/components/ollama/conversation.py +++ b/homeassistant/components/ollama/conversation.py @@ -2,41 +2,18 @@ from __future__ import annotations -from collections.abc import AsyncGenerator, AsyncIterator, Callable -import json -import logging -from typing import Any, Literal - -import ollama -from voluptuous_openapi import convert +from typing import Literal from homeassistant.components import assist_pipeline, conversation from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant -from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import device_registry as dr, intent, llm +from homeassistant.helpers import intent from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from . import OllamaConfigEntry -from .const import ( - CONF_KEEP_ALIVE, - CONF_MAX_HISTORY, - CONF_MODEL, - CONF_NUM_CTX, - CONF_PROMPT, - CONF_THINK, - DEFAULT_KEEP_ALIVE, - DEFAULT_MAX_HISTORY, - DEFAULT_NUM_CTX, - DOMAIN, -) -from .models import MessageHistory, MessageRole - -# Max number of back and forth with the LLM to generate a response -MAX_TOOL_ITERATIONS = 10 - -_LOGGER = logging.getLogger(__name__) +from .const import CONF_PROMPT, DOMAIN +from .entity import OllamaBaseLLMEntity async def async_setup_entry( @@ -55,129 +32,10 @@ async def async_setup_entry( ) -def _format_tool( - tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None -) -> dict[str, Any]: - """Format tool specification.""" - tool_spec = { - "name": tool.name, - "parameters": convert(tool.parameters, custom_serializer=custom_serializer), - } - if tool.description: - tool_spec["description"] = tool.description - return {"type": "function", "function": tool_spec} - - -def _fix_invalid_arguments(value: Any) -> Any: - """Attempt to repair incorrectly formatted json function arguments. - - Small models (for example llama3.1 8B) may produce invalid argument values - which we attempt to repair here. - """ - if not isinstance(value, str): - return value - if (value.startswith("[") and value.endswith("]")) or ( - value.startswith("{") and value.endswith("}") - ): - try: - return json.loads(value) - except json.decoder.JSONDecodeError: - pass - return value - - -def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]: - """Rewrite ollama tool arguments. - - This function improves tool use quality by fixing common mistakes made by - small local tool use models. This will repair invalid json arguments and - omit unnecessary arguments with empty values that will fail intent parsing. 
- """ - return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v} - - -def _convert_content( - chat_content: ( - conversation.Content - | conversation.ToolResultContent - | conversation.AssistantContent - ), -) -> ollama.Message: - """Create tool response content.""" - if isinstance(chat_content, conversation.ToolResultContent): - return ollama.Message( - role=MessageRole.TOOL.value, - content=json.dumps(chat_content.tool_result), - ) - if isinstance(chat_content, conversation.AssistantContent): - return ollama.Message( - role=MessageRole.ASSISTANT.value, - content=chat_content.content, - tool_calls=[ - ollama.Message.ToolCall( - function=ollama.Message.ToolCall.Function( - name=tool_call.tool_name, - arguments=tool_call.tool_args, - ) - ) - for tool_call in chat_content.tool_calls or () - ], - ) - if isinstance(chat_content, conversation.UserContent): - return ollama.Message( - role=MessageRole.USER.value, - content=chat_content.content, - ) - if isinstance(chat_content, conversation.SystemContent): - return ollama.Message( - role=MessageRole.SYSTEM.value, - content=chat_content.content, - ) - raise TypeError(f"Unexpected content type: {type(chat_content)}") - - -async def _transform_stream( - result: AsyncIterator[ollama.ChatResponse], -) -> AsyncGenerator[conversation.AssistantContentDeltaDict]: - """Transform the response stream into HA format. - - An Ollama streaming response may come in chunks like this: - - response: message=Message(role="assistant", content="Paris") - response: message=Message(role="assistant", content=".") - response: message=Message(role="assistant", content=""), done: True, done_reason: "stop" - response: message=Message(role="assistant", tool_calls=[...]) - response: message=Message(role="assistant", content=""), done: True, done_reason: "stop" - - This generator conforms to the chatlog delta stream expectations in that it - yields deltas, then the role only once the response is done. 
- """ - - new_msg = True - async for response in result: - _LOGGER.debug("Received response: %s", response) - response_message = response["message"] - chunk: conversation.AssistantContentDeltaDict = {} - if new_msg: - new_msg = False - chunk["role"] = "assistant" - if (tool_calls := response_message.get("tool_calls")) is not None: - chunk["tool_calls"] = [ - llm.ToolInput( - tool_name=tool_call["function"]["name"], - tool_args=_parse_tool_args(tool_call["function"]["arguments"]), - ) - for tool_call in tool_calls - ] - if (content := response_message.get("content")) is not None: - chunk["content"] = content - if response_message.get("done"): - new_msg = True - yield chunk - - class OllamaConversationEntity( - conversation.ConversationEntity, conversation.AbstractConversationAgent + conversation.ConversationEntity, + conversation.AbstractConversationAgent, + OllamaBaseLLMEntity, ): """Ollama conversation agent.""" @@ -185,17 +43,7 @@ class OllamaConversationEntity( def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None: """Initialize the agent.""" - self.entry = entry - self.subentry = subentry - self._attr_name = subentry.title - self._attr_unique_id = subentry.subentry_id - self._attr_device_info = dr.DeviceInfo( - identifiers={(DOMAIN, subentry.subentry_id)}, - name=subentry.title, - manufacturer="Ollama", - model=entry.data[CONF_MODEL], - entry_type=dr.DeviceEntryType.SERVICE, - ) + super().__init__(entry, subentry) if self.subentry.data.get(CONF_LLM_HASS_API): self._attr_supported_features = ( conversation.ConversationEntityFeature.CONTROL @@ -255,89 +103,6 @@ class OllamaConversationEntity( continue_conversation=chat_log.continue_conversation, ) - async def _async_handle_chat_log( - self, - chat_log: conversation.ChatLog, - ) -> None: - """Generate an answer for the chat log.""" - settings = {**self.entry.data, **self.subentry.data} - - client = self.entry.runtime_data - model = settings[CONF_MODEL] - - tools: list[dict[str, Any]] | None = None - if chat_log.llm_api: - tools = [ - _format_tool(tool, chat_log.llm_api.custom_serializer) - for tool in chat_log.llm_api.tools - ] - - message_history: MessageHistory = MessageHistory( - [_convert_content(content) for content in chat_log.content] - ) - max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY)) - self._trim_history(message_history, max_messages) - - # Get response - # To prevent infinite loops, we limit the number of iterations - for _iteration in range(MAX_TOOL_ITERATIONS): - try: - response_generator = await client.chat( - model=model, - # Make a copy of the messages because we mutate the list later - messages=list(message_history.messages), - tools=tools, - stream=True, - # keep_alive requires specifying unit. 
In this case, seconds - keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s", - options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)}, - think=settings.get(CONF_THINK), - ) - except (ollama.RequestError, ollama.ResponseError) as err: - _LOGGER.error("Unexpected error talking to Ollama server: %s", err) - raise HomeAssistantError( - f"Sorry, I had a problem talking to the Ollama server: {err}" - ) from err - - message_history.messages.extend( - [ - _convert_content(content) - async for content in chat_log.async_add_delta_content_stream( - self.entity_id, _transform_stream(response_generator) - ) - ] - ) - - if not chat_log.unresponded_tool_results: - break - - def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None: - """Trims excess messages from a single history. - - This sets the max history to allow a configurable size history may take - up in the context window. - - Note that some messages in the history may not be from ollama only, and - may come from other anents, so the assumptions here may not strictly hold, - but generally should be effective. - """ - if max_messages < 1: - # Keep all messages - return - - # Ignore the in progress user message - num_previous_rounds = message_history.num_user_messages - 1 - if num_previous_rounds >= max_messages: - # Trim history but keep system prompt (first message). - # Every other message should be an assistant message, so keep 2x - # message objects. Also keep the last in progress user message - num_keep = 2 * max_messages + 1 - drop_index = len(message_history.messages) - num_keep - message_history.messages = [ - message_history.messages[0], - *message_history.messages[drop_index:], - ] - async def _async_entry_update_listener( self, hass: HomeAssistant, entry: ConfigEntry ) -> None: diff --git a/homeassistant/components/ollama/entity.py b/homeassistant/components/ollama/entity.py new file mode 100644 index 00000000000..a577bf77573 --- /dev/null +++ b/homeassistant/components/ollama/entity.py @@ -0,0 +1,258 @@ +"""Base entity for the Ollama integration.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, AsyncIterator, Callable +import json +import logging +from typing import Any + +import ollama +from voluptuous_openapi import convert + +from homeassistant.components import conversation +from homeassistant.config_entries import ConfigSubentry +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import device_registry as dr, llm +from homeassistant.helpers.entity import Entity + +from . import OllamaConfigEntry +from .const import ( + CONF_KEEP_ALIVE, + CONF_MAX_HISTORY, + CONF_MODEL, + CONF_NUM_CTX, + CONF_THINK, + DEFAULT_KEEP_ALIVE, + DEFAULT_MAX_HISTORY, + DEFAULT_NUM_CTX, + DOMAIN, +) +from .models import MessageHistory, MessageRole + +# Max number of back and forth with the LLM to generate a response +MAX_TOOL_ITERATIONS = 10 + +_LOGGER = logging.getLogger(__name__) + + +def _format_tool( + tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None +) -> dict[str, Any]: + """Format tool specification.""" + tool_spec = { + "name": tool.name, + "parameters": convert(tool.parameters, custom_serializer=custom_serializer), + } + if tool.description: + tool_spec["description"] = tool.description + return {"type": "function", "function": tool_spec} + + +def _fix_invalid_arguments(value: Any) -> Any: + """Attempt to repair incorrectly formatted json function arguments. 
+
+    Small models (for example, llama3.1 8B) may produce invalid argument
+    values, which we attempt to repair here.
+    """
+    if not isinstance(value, str):
+        return value
+    if (value.startswith("[") and value.endswith("]")) or (
+        value.startswith("{") and value.endswith("}")
+    ):
+        try:
+            return json.loads(value)
+        except json.decoder.JSONDecodeError:
+            pass
+    return value
+
+
+def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
+    """Rewrite Ollama tool arguments.
+
+    This function improves tool use quality by fixing common mistakes made by
+    small local tool-use models: it repairs invalid JSON arguments and omits
+    unnecessary arguments with empty values that would fail intent parsing.
+    """
+    return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}
+
+
+def _convert_content(
+    chat_content: (
+        conversation.Content
+        | conversation.ToolResultContent
+        | conversation.AssistantContent
+    ),
+) -> ollama.Message:
+    """Convert chat log content to an Ollama message."""
+    if isinstance(chat_content, conversation.ToolResultContent):
+        return ollama.Message(
+            role=MessageRole.TOOL.value,
+            content=json.dumps(chat_content.tool_result),
+        )
+    if isinstance(chat_content, conversation.AssistantContent):
+        return ollama.Message(
+            role=MessageRole.ASSISTANT.value,
+            content=chat_content.content,
+            tool_calls=[
+                ollama.Message.ToolCall(
+                    function=ollama.Message.ToolCall.Function(
+                        name=tool_call.tool_name,
+                        arguments=tool_call.tool_args,
+                    )
+                )
+                for tool_call in chat_content.tool_calls or ()
+            ],
+        )
+    if isinstance(chat_content, conversation.UserContent):
+        return ollama.Message(
+            role=MessageRole.USER.value,
+            content=chat_content.content,
+        )
+    if isinstance(chat_content, conversation.SystemContent):
+        return ollama.Message(
+            role=MessageRole.SYSTEM.value,
+            content=chat_content.content,
+        )
+    raise TypeError(f"Unexpected content type: {type(chat_content)}")
+
+
+async def _transform_stream(
+    result: AsyncIterator[ollama.ChatResponse],
+) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
+    """Transform the response stream into HA format.
+
+    An Ollama streaming response may come in chunks like this:
+
+    response: message=Message(role="assistant", content="Paris")
+    response: message=Message(role="assistant", content=".")
+    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
+    response: message=Message(role="assistant", tool_calls=[...])
+    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
+
+    This generator conforms to the chat log delta stream expectations: the
+    role is yielded once at the start of each message, followed by content
+    and tool-call deltas.
+    """
+
+    new_msg = True
+    async for response in result:
+        _LOGGER.debug("Received response: %s", response)
+        response_message = response["message"]
+        chunk: conversation.AssistantContentDeltaDict = {}
+        if new_msg:
+            new_msg = False
+            chunk["role"] = "assistant"
+        if (tool_calls := response_message.get("tool_calls")) is not None:
+            chunk["tool_calls"] = [
+                llm.ToolInput(
+                    tool_name=tool_call["function"]["name"],
+                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
+                )
+                for tool_call in tool_calls
+            ]
+        if (content := response_message.get("content")) is not None:
+            chunk["content"] = content
+        if response_message.get("done"):
+            new_msg = True
+        yield chunk
+
+
+class OllamaBaseLLMEntity(Entity):
+    """Ollama base LLM entity."""
+
+    def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None:
+        """Initialize the entity."""
+        self.entry = entry
+        self.subentry = subentry
+        self._attr_name = subentry.title
+        self._attr_unique_id = subentry.subentry_id
+        self._attr_device_info = dr.DeviceInfo(
+            identifiers={(DOMAIN, subentry.subentry_id)},
+            name=subentry.title,
+            manufacturer="Ollama",
+            model=entry.data[CONF_MODEL],
+            entry_type=dr.DeviceEntryType.SERVICE,
+        )
+
+    async def _async_handle_chat_log(
+        self,
+        chat_log: conversation.ChatLog,
+    ) -> None:
+        """Generate an answer for the chat log."""
+        settings = {**self.entry.data, **self.subentry.data}
+
+        client = self.entry.runtime_data
+        model = settings[CONF_MODEL]
+
+        tools: list[dict[str, Any]] | None = None
+        if chat_log.llm_api:
+            tools = [
+                _format_tool(tool, chat_log.llm_api.custom_serializer)
+                for tool in chat_log.llm_api.tools
+            ]
+
+        message_history: MessageHistory = MessageHistory(
+            [_convert_content(content) for content in chat_log.content]
+        )
+        max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
+        self._trim_history(message_history, max_messages)
+
+        # Get response
+        # To prevent infinite loops, we limit the number of iterations
+        for _iteration in range(MAX_TOOL_ITERATIONS):
+            try:
+                response_generator = await client.chat(
+                    model=model,
+                    # Make a copy of the messages because we mutate the list later
+                    messages=list(message_history.messages),
+                    tools=tools,
+                    stream=True,
+                    # keep_alive requires specifying a unit; in this case, seconds
+                    keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
+                    options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
+                    think=settings.get(CONF_THINK),
+                )
+            except (ollama.RequestError, ollama.ResponseError) as err:
+                _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
+                raise HomeAssistantError(
+                    f"Sorry, I had a problem talking to the Ollama server: {err}"
+                ) from err
+
+            message_history.messages.extend(
+                [
+                    _convert_content(content)
+                    async for content in chat_log.async_add_delta_content_stream(
+                        self.entity_id, _transform_stream(response_generator)
+                    )
+                ]
+            )
+
+            if not chat_log.unresponded_tool_results:
+                break
+
+    def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
+        """Trim excess messages from a single history.
+
+        This caps how much of the context window the history may take up,
+        based on a configurable maximum number of messages.
+
+        Note that not all messages in the history necessarily come from
+        Ollama; some may come from other agents, so the assumptions here may
+        not strictly hold, but they are generally effective.
+        """
+        if max_messages < 1:
+            # Keep all messages
+            return
+
+        # Ignore the in-progress user message
+        num_previous_rounds = message_history.num_user_messages - 1
+        if num_previous_rounds >= max_messages:
+            # Trim history but keep the system prompt (first message).
+            # Each user message should be paired with an assistant message,
+            # so keep 2x message objects. Also keep the last in-progress
+            # user message.
+            num_keep = 2 * max_messages + 1
+            drop_index = len(message_history.messages) - num_keep
+            message_history.messages = [
+                message_history.messages[0],
+                *message_history.messages[drop_index:],
+            ]
diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py
index cebb185bd08..d33fffe7152 100644
--- a/tests/components/ollama/test_conversation.py
+++ b/tests/components/ollama/test_conversation.py
@@ -206,7 +206,7 @@ async def test_template_variables(
         ),
     ],
 )
-@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
+@patch("homeassistant.components.ollama.entity.llm.AssistAPI._async_get_tools")
 async def test_function_call(
     mock_get_tools,
     hass: HomeAssistant,
@@ -293,7 +293,7 @@ async def test_function_call(
     )
 
 
-@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
+@patch("homeassistant.components.ollama.entity.llm.AssistAPI._async_get_tools")
 async def test_function_exception(
     mock_get_tools,
     hass: HomeAssistant,
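
Reviewer note: the snippet below is an illustration, not part of the patch. It
shows what the argument-repair helpers moved into entity.py do with typical
output from a small tool-use model; importing the private _parse_tool_args
helper is for demonstration only.

    from homeassistant.components.ollama.entity import _parse_tool_args

    # A small model may JSON-encode a list as a string and emit empty values.
    raw_args = {
        "name": '["bedroom light"]',  # bracketed string -> decoded to a real list
        "brightness": "",             # empty value -> dropped before intent parsing
        "domain": "light",            # ordinary string -> passed through unchanged
    }
    assert _parse_tool_args(raw_args) == {"name": ["bedroom light"], "domain": "light"}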
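
Reviewer note: a worked example (not from the patch) of the _trim_history
arithmetic, using plain strings in place of the ollama.Message objects the
real history holds.

    # History with three user turns; "user 3" is the in-progress message.
    messages = ["system", "user 1", "assistant 1", "user 2", "assistant 2", "user 3"]
    max_messages = 1

    num_user_messages = 3
    num_previous_rounds = num_user_messages - 1  # = 2, and 2 >= max_messages, so trim
    num_keep = 2 * max_messages + 1              # one full round plus the in-progress message
    drop_index = len(messages) - num_keep        # = 3
    trimmed = [messages[0], *messages[drop_index:]]
    assert trimmed == ["system", "user 2", "assistant 2", "user 3"]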
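
Reviewer note: the point of this split is that platforms other than
conversation can reuse the Ollama plumbing. Below is a minimal sketch of a
hypothetical subclass; MyOllamaEntity and its async_generate_text method are
invented for illustration and assume chat_log.content is a list, while
OllamaBaseLLMEntity, _async_handle_chat_log, and the conversation types all
come from the patch above.

    from homeassistant.components import conversation
    from homeassistant.components.ollama import OllamaConfigEntry
    from homeassistant.components.ollama.entity import OllamaBaseLLMEntity
    from homeassistant.config_entries import ConfigSubentry


    class MyOllamaEntity(OllamaBaseLLMEntity):
        """Hypothetical entity reusing the shared Ollama LLM plumbing."""

        def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None:
            """Initialize; the base class sets name, unique_id, and device info."""
            super().__init__(entry, subentry)

        async def async_generate_text(self, chat_log: conversation.ChatLog) -> str | None:
            """Run the chat log through Ollama and return the final answer text."""
            # _async_handle_chat_log streams the model output (including any
            # tool-call round trips) into the chat log in place.
            await self._async_handle_chat_log(chat_log)
            # The last assistant entry now holds the model's reply.
            for content in reversed(chat_log.content):
                if isinstance(content, conversation.AssistantContent):
                    return content.content
            return None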