mirror of https://github.com/home-assistant/core.git (synced 2025-07-21 12:17:07 +00:00)
Split Ollama entity (#147769)
This commit is contained in:
parent 70856bd92a
commit bf74ba990a
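This commit splits the Ollama conversation agent so that the shared LLM plumbing (device info, tool formatting, streaming, history trimming) moves into a new OllamaBaseLLMEntity in entity.py, while conversation.py keeps only the conversation-specific glue. A minimal sketch of the resulting class relationship, with method bodies elided; the class and method names are taken from the diff below, everything else is illustrative:

# Sketch only: bodies elided, see the full diff below for the real implementation.
from homeassistant.components import conversation
from homeassistant.helpers.entity import Entity


class OllamaBaseLLMEntity(Entity):
    """Shared Ollama LLM logic (new in entity.py)."""

    async def _async_handle_chat_log(self, chat_log: conversation.ChatLog) -> None:
        """Stream a model response into the chat log, looping over tool calls."""


class OllamaConversationEntity(
    conversation.ConversationEntity,
    conversation.AbstractConversationAgent,
    OllamaBaseLLMEntity,
):
    """Conversation agent (conversation.py), now built on the shared base."""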
homeassistant/components/ollama/conversation.py
@@ -2,41 +2,18 @@

from __future__ import annotations

from collections.abc import AsyncGenerator, AsyncIterator, Callable
import json
import logging
from typing import Any, Literal

import ollama
from voluptuous_openapi import convert
from typing import Literal

from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, intent, llm
from homeassistant.helpers import intent
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback

from . import OllamaConfigEntry
from .const import (
    CONF_KEEP_ALIVE,
    CONF_MAX_HISTORY,
    CONF_MODEL,
    CONF_NUM_CTX,
    CONF_PROMPT,
    CONF_THINK,
    DEFAULT_KEEP_ALIVE,
    DEFAULT_MAX_HISTORY,
    DEFAULT_NUM_CTX,
    DOMAIN,
)
from .models import MessageHistory, MessageRole

# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10

_LOGGER = logging.getLogger(__name__)
from .const import CONF_PROMPT, DOMAIN
from .entity import OllamaBaseLLMEntity


async def async_setup_entry(
@@ -55,129 +32,10 @@ async def async_setup_entry(
    )


def _format_tool(
    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
    """Format tool specification."""
    tool_spec = {
        "name": tool.name,
        "parameters": convert(tool.parameters, custom_serializer=custom_serializer),
    }
    if tool.description:
        tool_spec["description"] = tool.description
    return {"type": "function", "function": tool_spec}


def _fix_invalid_arguments(value: Any) -> Any:
    """Attempt to repair incorrectly formatted json function arguments.

    Small models (for example llama3.1 8B) may produce invalid argument values
    which we attempt to repair here.
    """
    if not isinstance(value, str):
        return value
    if (value.startswith("[") and value.endswith("]")) or (
        value.startswith("{") and value.endswith("}")
    ):
        try:
            return json.loads(value)
        except json.decoder.JSONDecodeError:
            pass
    return value


def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
    """Rewrite ollama tool arguments.

    This function improves tool use quality by fixing common mistakes made by
    small local tool use models. This will repair invalid json arguments and
    omit unnecessary arguments with empty values that will fail intent parsing.
    """
    return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}


def _convert_content(
    chat_content: (
        conversation.Content
        | conversation.ToolResultContent
        | conversation.AssistantContent
    ),
) -> ollama.Message:
    """Create tool response content."""
    if isinstance(chat_content, conversation.ToolResultContent):
        return ollama.Message(
            role=MessageRole.TOOL.value,
            content=json.dumps(chat_content.tool_result),
        )
    if isinstance(chat_content, conversation.AssistantContent):
        return ollama.Message(
            role=MessageRole.ASSISTANT.value,
            content=chat_content.content,
            tool_calls=[
                ollama.Message.ToolCall(
                    function=ollama.Message.ToolCall.Function(
                        name=tool_call.tool_name,
                        arguments=tool_call.tool_args,
                    )
                )
                for tool_call in chat_content.tool_calls or ()
            ],
        )
    if isinstance(chat_content, conversation.UserContent):
        return ollama.Message(
            role=MessageRole.USER.value,
            content=chat_content.content,
        )
    if isinstance(chat_content, conversation.SystemContent):
        return ollama.Message(
            role=MessageRole.SYSTEM.value,
            content=chat_content.content,
        )
    raise TypeError(f"Unexpected content type: {type(chat_content)}")


async def _transform_stream(
    result: AsyncIterator[ollama.ChatResponse],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
    """Transform the response stream into HA format.

    An Ollama streaming response may come in chunks like this:

    response: message=Message(role="assistant", content="Paris")
    response: message=Message(role="assistant", content=".")
    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
    response: message=Message(role="assistant", tool_calls=[...])
    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"

    This generator conforms to the chatlog delta stream expectations in that it
    yields deltas, then the role only once the response is done.
    """

    new_msg = True
    async for response in result:
        _LOGGER.debug("Received response: %s", response)
        response_message = response["message"]
        chunk: conversation.AssistantContentDeltaDict = {}
        if new_msg:
            new_msg = False
            chunk["role"] = "assistant"
        if (tool_calls := response_message.get("tool_calls")) is not None:
            chunk["tool_calls"] = [
                llm.ToolInput(
                    tool_name=tool_call["function"]["name"],
                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
                )
                for tool_call in tool_calls
            ]
        if (content := response_message.get("content")) is not None:
            chunk["content"] = content
        if response_message.get("done"):
            new_msg = True
        yield chunk


class OllamaConversationEntity(
    conversation.ConversationEntity, conversation.AbstractConversationAgent
    conversation.ConversationEntity,
    conversation.AbstractConversationAgent,
    OllamaBaseLLMEntity,
):
    """Ollama conversation agent."""

@@ -185,17 +43,7 @@ class OllamaConversationEntity(

    def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the agent."""
        self.entry = entry
        self.subentry = subentry
        self._attr_name = subentry.title
        self._attr_unique_id = subentry.subentry_id
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, subentry.subentry_id)},
            name=subentry.title,
            manufacturer="Ollama",
            model=entry.data[CONF_MODEL],
            entry_type=dr.DeviceEntryType.SERVICE,
        )
        super().__init__(entry, subentry)
        if self.subentry.data.get(CONF_LLM_HASS_API):
            self._attr_supported_features = (
                conversation.ConversationEntityFeature.CONTROL

@@ -255,89 +103,6 @@ class OllamaConversationEntity(
            continue_conversation=chat_log.continue_conversation,
        )

    async def _async_handle_chat_log(
        self,
        chat_log: conversation.ChatLog,
    ) -> None:
        """Generate an answer for the chat log."""
        settings = {**self.entry.data, **self.subentry.data}

        client = self.entry.runtime_data
        model = settings[CONF_MODEL]

        tools: list[dict[str, Any]] | None = None
        if chat_log.llm_api:
            tools = [
                _format_tool(tool, chat_log.llm_api.custom_serializer)
                for tool in chat_log.llm_api.tools
            ]

        message_history: MessageHistory = MessageHistory(
            [_convert_content(content) for content in chat_log.content]
        )
        max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
        self._trim_history(message_history, max_messages)

        # Get response
        # To prevent infinite loops, we limit the number of iterations
        for _iteration in range(MAX_TOOL_ITERATIONS):
            try:
                response_generator = await client.chat(
                    model=model,
                    # Make a copy of the messages because we mutate the list later
                    messages=list(message_history.messages),
                    tools=tools,
                    stream=True,
                    # keep_alive requires specifying unit. In this case, seconds
                    keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
                    options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
                    think=settings.get(CONF_THINK),
                )
            except (ollama.RequestError, ollama.ResponseError) as err:
                _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
                raise HomeAssistantError(
                    f"Sorry, I had a problem talking to the Ollama server: {err}"
                ) from err

            message_history.messages.extend(
                [
                    _convert_content(content)
                    async for content in chat_log.async_add_delta_content_stream(
                        self.entity_id, _transform_stream(response_generator)
                    )
                ]
            )

            if not chat_log.unresponded_tool_results:
                break

    def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
        """Trims excess messages from a single history.

        This sets the max history to limit how much of the context window a
        configurable-size history may take up.

        Note that some messages in the history may not be from ollama only, and
        may come from other agents, so the assumptions here may not strictly hold,
        but generally should be effective.
        """
        if max_messages < 1:
            # Keep all messages
            return

        # Ignore the in progress user message
        num_previous_rounds = message_history.num_user_messages - 1
        if num_previous_rounds >= max_messages:
            # Trim history but keep system prompt (first message).
            # Every other message should be an assistant message, so keep 2x
            # message objects. Also keep the last in progress user message
            num_keep = 2 * max_messages + 1
            drop_index = len(message_history.messages) - num_keep
            message_history.messages = [
                message_history.messages[0],
                *message_history.messages[drop_index:],
            ]

    async def _async_entry_update_listener(
        self, hass: HomeAssistant, entry: ConfigEntry
    ) -> None:
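For context on the helpers removed above (they reappear unchanged in entity.py below): _format_tool wraps a Home Assistant llm.Tool into the function-tool dict the Ollama client expects. A rough sketch of the resulting shape for a hypothetical tool; the name, description, and parameter schema here are invented, only the outer structure comes from the code:

# Illustrative output shape only; the real "parameters" value is whatever
# voluptuous_openapi.convert() produces for the tool's schema.
tool_spec = {
    "type": "function",
    "function": {
        "name": "HassTurnOn",  # hypothetical tool name
        "description": "Turns on a device or entity",  # only set if tool.description
        "parameters": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
        },
    },
}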
homeassistant/components/ollama/entity.py (new file, 258 lines)
@@ -0,0 +1,258 @@
"""Base entity for the Ollama integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Callable
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import ollama
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr, llm
|
||||
from homeassistant.helpers.entity import Entity
|
||||
|
||||
from . import OllamaConfigEntry
|
||||
from .const import (
|
||||
CONF_KEEP_ALIVE,
|
||||
CONF_MAX_HISTORY,
|
||||
CONF_MODEL,
|
||||
CONF_NUM_CTX,
|
||||
CONF_THINK,
|
||||
DEFAULT_KEEP_ALIVE,
|
||||
DEFAULT_MAX_HISTORY,
|
||||
DEFAULT_NUM_CTX,
|
||||
DOMAIN,
|
||||
)
|
||||
from .models import MessageHistory, MessageRole
|
||||
|
||||
# Max number of back and forth with the LLM to generate a response
|
||||
MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _format_tool(
|
||||
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||
) -> dict[str, Any]:
|
||||
"""Format tool specification."""
|
||||
tool_spec = {
|
||||
"name": tool.name,
|
||||
"parameters": convert(tool.parameters, custom_serializer=custom_serializer),
|
||||
}
|
||||
if tool.description:
|
||||
tool_spec["description"] = tool.description
|
||||
return {"type": "function", "function": tool_spec}
|
||||
|
||||
|
||||
def _fix_invalid_arguments(value: Any) -> Any:
|
||||
"""Attempt to repair incorrectly formatted json function arguments.
|
||||
|
||||
Small models (for example llama3.1 8B) may produce invalid argument values
|
||||
which we attempt to repair here.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
if (value.startswith("[") and value.endswith("]")) or (
|
||||
value.startswith("{") and value.endswith("}")
|
||||
):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.decoder.JSONDecodeError:
|
||||
pass
|
||||
return value
|
||||
|
||||
|
||||
def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Rewrite ollama tool arguments.
|
||||
|
||||
This function improves tool use quality by fixing common mistakes made by
|
||||
small local tool use models. This will repair invalid json arguments and
|
||||
omit unnecessary arguments with empty values that will fail intent parsing.
|
||||
"""
|
||||
return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}
|
||||
|
||||
|
||||
def _convert_content(
|
||||
chat_content: (
|
||||
conversation.Content
|
||||
| conversation.ToolResultContent
|
||||
| conversation.AssistantContent
|
||||
),
|
||||
) -> ollama.Message:
|
||||
"""Create tool response content."""
|
||||
if isinstance(chat_content, conversation.ToolResultContent):
|
||||
return ollama.Message(
|
||||
role=MessageRole.TOOL.value,
|
||||
content=json.dumps(chat_content.tool_result),
|
||||
)
|
||||
if isinstance(chat_content, conversation.AssistantContent):
|
||||
return ollama.Message(
|
||||
role=MessageRole.ASSISTANT.value,
|
||||
content=chat_content.content,
|
||||
tool_calls=[
|
||||
ollama.Message.ToolCall(
|
||||
function=ollama.Message.ToolCall.Function(
|
||||
name=tool_call.tool_name,
|
||||
arguments=tool_call.tool_args,
|
||||
)
|
||||
)
|
||||
for tool_call in chat_content.tool_calls or ()
|
||||
],
|
||||
)
|
||||
if isinstance(chat_content, conversation.UserContent):
|
||||
return ollama.Message(
|
||||
role=MessageRole.USER.value,
|
||||
content=chat_content.content,
|
||||
)
|
||||
if isinstance(chat_content, conversation.SystemContent):
|
||||
return ollama.Message(
|
||||
role=MessageRole.SYSTEM.value,
|
||||
content=chat_content.content,
|
||||
)
|
||||
raise TypeError(f"Unexpected content type: {type(chat_content)}")
|
||||
|
||||
|
||||
async def _transform_stream(
|
||||
result: AsyncIterator[ollama.ChatResponse],
|
||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||
"""Transform the response stream into HA format.
|
||||
|
||||
An Ollama streaming response may come in chunks like this:
|
||||
|
||||
response: message=Message(role="assistant", content="Paris")
|
||||
response: message=Message(role="assistant", content=".")
|
||||
response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
|
||||
response: message=Message(role="assistant", tool_calls=[...])
|
||||
response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
|
||||
|
||||
This generator conforms to the chatlog delta stream expectations in that it
|
||||
yields deltas, then the role only once the response is done.
|
||||
"""
|
||||
|
||||
new_msg = True
|
||||
async for response in result:
|
||||
_LOGGER.debug("Received response: %s", response)
|
||||
response_message = response["message"]
|
||||
chunk: conversation.AssistantContentDeltaDict = {}
|
||||
if new_msg:
|
||||
new_msg = False
|
||||
chunk["role"] = "assistant"
|
||||
if (tool_calls := response_message.get("tool_calls")) is not None:
|
||||
chunk["tool_calls"] = [
|
||||
llm.ToolInput(
|
||||
tool_name=tool_call["function"]["name"],
|
||||
tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
if (content := response_message.get("content")) is not None:
|
||||
chunk["content"] = content
|
||||
if response_message.get("done"):
|
||||
new_msg = True
|
||||
yield chunk
|
||||
|
||||
|
||||
class OllamaBaseLLMEntity(Entity):
|
||||
"""Ollama base LLM entity."""
|
||||
|
||||
def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None:
|
||||
"""Initialize the entity."""
|
||||
self.entry = entry
|
||||
self.subentry = subentry
|
||||
self._attr_name = subentry.title
|
||||
self._attr_unique_id = subentry.subentry_id
|
||||
self._attr_device_info = dr.DeviceInfo(
|
||||
identifiers={(DOMAIN, subentry.subentry_id)},
|
||||
name=subentry.title,
|
||||
manufacturer="Ollama",
|
||||
model=entry.data[CONF_MODEL],
|
||||
entry_type=dr.DeviceEntryType.SERVICE,
|
||||
)
|
||||
|
||||
async def _async_handle_chat_log(
|
||||
self,
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> None:
|
||||
"""Generate an answer for the chat log."""
|
||||
settings = {**self.entry.data, **self.subentry.data}
|
||||
|
||||
client = self.entry.runtime_data
|
||||
model = settings[CONF_MODEL]
|
||||
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
if chat_log.llm_api:
|
||||
tools = [
|
||||
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
||||
for tool in chat_log.llm_api.tools
|
||||
]
|
||||
|
||||
message_history: MessageHistory = MessageHistory(
|
||||
[_convert_content(content) for content in chat_log.content]
|
||||
)
|
||||
max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
|
||||
self._trim_history(message_history, max_messages)
|
||||
|
||||
# Get response
|
||||
# To prevent infinite loops, we limit the number of iterations
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
try:
|
||||
response_generator = await client.chat(
|
||||
model=model,
|
||||
# Make a copy of the messages because we mutate the list later
|
||||
messages=list(message_history.messages),
|
||||
tools=tools,
|
||||
stream=True,
|
||||
# keep_alive requires specifying unit. In this case, seconds
|
||||
keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
|
||||
options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
|
||||
think=settings.get(CONF_THINK),
|
||||
)
|
||||
except (ollama.RequestError, ollama.ResponseError) as err:
|
||||
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
|
||||
raise HomeAssistantError(
|
||||
f"Sorry, I had a problem talking to the Ollama server: {err}"
|
||||
) from err
|
||||
|
||||
message_history.messages.extend(
|
||||
[
|
||||
_convert_content(content)
|
||||
async for content in chat_log.async_add_delta_content_stream(
|
||||
self.entity_id, _transform_stream(response_generator)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if not chat_log.unresponded_tool_results:
|
||||
break
|
||||
|
||||
    def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
        """Trims excess messages from a single history.

        This sets the max history to limit how much of the context window a
        configurable-size history may take up.

        Note that some messages in the history may not be from ollama only, and
        may come from other agents, so the assumptions here may not strictly hold,
        but generally should be effective.
        """
        if max_messages < 1:
            # Keep all messages
            return

        # Ignore the in progress user message
        num_previous_rounds = message_history.num_user_messages - 1
        if num_previous_rounds >= max_messages:
            # Trim history but keep system prompt (first message).
            # Every other message should be an assistant message, so keep 2x
            # message objects. Also keep the last in progress user message
            num_keep = 2 * max_messages + 1
            drop_index = len(message_history.messages) - num_keep
            message_history.messages = [
                message_history.messages[0],
                *message_history.messages[drop_index:],
            ]
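To make the argument clean-up described in the _fix_invalid_arguments and _parse_tool_args docstrings concrete, here is a small illustrative example; the argument names and values are invented, only the two helper functions come from the code above:

# Hypothetical tool-call arguments as a small tool-use model might emit them.
raw_args = {
    "name": "kitchen light",
    "brightness": "",                # empty value: dropped by _parse_tool_args
    "rgb_color": "[255, 255, 255]",  # JSON encoded as a string: repaired
}

cleaned = _parse_tool_args(raw_args)
assert cleaned == {"name": "kitchen light", "rgb_color": [255, 255, 255]}

Similarly, _trim_history with max_messages=2 computes num_keep = 2 * 2 + 1 = 5, so once at least two previous user turns exist, the rebuilt list is the first message (the system prompt) plus the five most recent messages.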
@@ -206,7 +206,7 @@ async def test_template_variables(
        ),
    ],
)
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
@patch("homeassistant.components.ollama.entity.llm.AssistAPI._async_get_tools")
async def test_function_call(
    mock_get_tools,
    hass: HomeAssistant,

@@ -293,7 +293,7 @@ async def test_function_call(
    )


@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
@patch("homeassistant.components.ollama.entity.llm.AssistAPI._async_get_tools")
async def test_function_exception(
    mock_get_tools,
    hass: HomeAssistant,
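The only test change is the patch target: llm.AssistAPI is now imported, and therefore looked up, in homeassistant.components.ollama.entity rather than in the conversation module, so the tests patch it there. A minimal sketch of the pattern, with the test body and fixtures elided and the mock value purely illustrative:

# Sketch only: patch the symbol where it is used, i.e. the entity module after this split.
from unittest.mock import patch


@patch("homeassistant.components.ollama.entity.llm.AssistAPI._async_get_tools")
async def test_function_call(mock_get_tools, hass):
    mock_get_tools.return_value = []
    # ... run a conversation through the agent and assert on tool usage ...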