Ollama tool calling

Allen Porter 2024-07-19 15:10:28 +00:00
parent 975cfa6457
commit 8f688ee079
5 changed files with 562 additions and 95 deletions

View File

@ -2,25 +2,31 @@
from __future__ import annotations
from collections.abc import Callable
import json
import logging
import time
from typing import Any, Literal
import ollama
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
intent,
template,
llm,
)
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
@ -34,12 +40,28 @@ from .const import (
DEFAULT_MAX_HISTORY,
DEFAULT_PROMPT,
DOMAIN,
KEEP_ALIVE_FOREVER,
MAX_HISTORY_SECONDS,
)
from .models import MessageHistory, MessageRole
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
_LOGGER = logging.getLogger(__name__)
def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
"""Format tool specification."""
tool_spec = {
"name": tool.name,
"parameters": convert(tool.parameters, custom_serializer=custom_serializer),
}
if tool.description:
tool_spec["description"] = tool.description
return tool_spec
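For illustration, voluptuous_openapi.convert renders a tool's voluptuous schema as a JSON-schema-style dict. A rough sketch for a hypothetical tool (the name, schema, and output below are examples, not values from this commit):

    import voluptuous as vol
    from voluptuous_openapi import convert

    schema = vol.Schema({vol.Optional("brightness", description="Value 0-255"): int})
    tool_spec = {
        "name": "set_brightness",  # hypothetical tool name
        "parameters": convert(schema),
        "description": "Set a light's brightness",
    }
    # convert() yields roughly:
    # {"type": "object", "properties":
    #   {"brightness": {"type": "integer", "description": "Value 0-255"}}}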
async def async_setup_entry(
hass: HomeAssistant,
@ -90,10 +112,55 @@ class OllamaConversationEntity(
) -> conversation.ConversationResult:
"""Process a sentence."""
settings = {**self.entry.data, **self.entry.options}
options = self.entry.options
client = self.hass.data[DOMAIN][self.entry.entry_id]
conversation_id = user_input.conversation_id or ulid.ulid_now()
model = settings[CONF_MODEL]
intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.APIInstance | None = None
tools: dict[str, dict[str, Any]] | None = None
user_name: str | None = None
llm_context = llm.LLMContext(
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
_LOGGER.debug("CONF_LLM_HASS_API=%s", settings.get(CONF_LLM_HASS_API))
if settings.get(CONF_LLM_HASS_API):
try:
llm_api = await llm.async_get_api(
self.hass,
settings[CONF_LLM_HASS_API],
llm_context,
)
except HomeAssistantError as err:
_LOGGER.error("Error getting LLM API: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Error preparing LLM API: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
tools = {
tool.name: _format_tool(tool, llm_api.custom_serializer)
for tool in llm_api.tools
}
_LOGGER.debug("tools=%s", tools)
if (
user_input.context
and user_input.context.user_id
and (
user := await self.hass.auth.async_get_user(user_input.context.user_id)
)
):
user_name = user.name
# Look up message history
message_history: MessageHistory | None = None
@ -102,13 +169,23 @@ class OllamaConversationEntity(
# New history
#
# Render prompt and error out early if there's a problem
try:
prompt_parts = [
template.Template(
llm.BASE_PROMPT
+ settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
self.hass,
).async_render(
{
"ha_name": self.hass.config.location_name,
"user_name": user_name,
"llm_context": llm_context,
},
parse_result=False,
)
]
except TemplateError as err:
_LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem generating my prompt: {err}",
@ -117,6 +194,16 @@ class OllamaConversationEntity(
response=intent_response, conversation_id=conversation_id
)
if llm_api:
_LOGGER.debug("llm api prompt parts")
prompt_parts.append(llm_api.api_prompt)
else:
_LOGGER.debug("no llm api prompt parts")
prompt = "\n".join(prompt_parts)
_LOGGER.debug("Prompt: %s", prompt)
message_history = MessageHistory(
timestamp=time.monotonic(),
messages=[
@ -146,32 +233,71 @@ class OllamaConversationEntity(
)
# Get response
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
response = await client.chat(
model=model,
# Make a copy of the messages because we mutate the list later
messages=list(message_history.messages),
stream=False,
tools=tools,
keep_alive=KEEP_ALIVE_FOREVER,
)
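# Note: tools is passed on every iteration so the model can chain calls.
# Ollama's keep_alive accepts seconds or a duration string, and a negative
# value keeps the model loaded indefinitely; presumably that is what
# KEEP_ALIVE_FOREVER encodes.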
except (ollama.RequestError, ollama.ResponseError) as err:
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to the Ollama server: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
_LOGGER.debug("Response: %s", response)
response_message = response["message"]
tool_calls = response_message.get("tool_calls")
def message_convert(response_message: Any) -> ollama.Message:
"""Convert a response message to an ollama.Message, keeping only set fields."""
msg = ollama.Message(role=response_message["role"])
if content := response_message.get("content"):
msg["content"] = content
if tool_calls := response_message.get("tool_calls"):
msg["tool_calls"] = tool_calls
return msg
message_history.messages.append(message_convert(response_message))
if not tool_calls or not llm_api:
_LOGGER.debug("tool_calls=%s", tool_calls)
_LOGGER.debug("llm_api=%s", llm_api)
break
_LOGGER.debug("Response: %s", response_message.get("content"))
for tool_call in tool_calls:
tool_input = llm.ToolInput(
tool_name=tool_call["function"]["name"],
tool_args=json.loads(tool_call["function"]["arguments"]),
)
_LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)
try:
tool_response = await llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__}
if str(e):
tool_response["error_text"] = str(e)
_LOGGER.debug("Tool response: %s", tool_response)
message_history.messages.append(
ollama.Message(
role="tool", content=json.dumps(tool_response)
)
)
# Create intent response
intent_response = intent.IntentResponse(language=user_input.language)
@ -204,62 +330,3 @@ class OllamaConversationEntity(
message_history.messages = [
message_history.messages[0]
] + message_history.messages[drop_index:]
def _generate_prompt(self, raw_prompt: str) -> str:
"""Generate a prompt for the user."""
return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
"ha_language": self.hass.config.language,
"exposed_entities": self._get_exposed_entities(),
},
parse_result=False,
)
def _get_exposed_entities(self) -> list[ExposedEntity]:
"""Get state list of exposed entities."""
area_registry = ar.async_get(self.hass)
entity_registry = er.async_get(self.hass)
device_registry = dr.async_get(self.hass)
exposed_entities = []
exposed_states = [
state
for state in self.hass.states.async_all()
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
]
for state in exposed_states:
entity_entry = entity_registry.async_get(state.entity_id)
names = [state.name]
area_names = []
if entity_entry is not None:
# Add aliases
names.extend(entity_entry.aliases)
if entity_entry.area_id and (
area := area_registry.async_get_area(entity_entry.area_id)
):
# Entity is in area
area_names.append(area.name)
area_names.extend(area.aliases)
elif entity_entry.device_id and (
device := device_registry.async_get(entity_entry.device_id)
):
# Check device area
if device.area_id and (
area := area_registry.async_get_area(device.area_id)
):
area_names.append(area.name)
area_names.extend(area.aliases)
exposed_entities.append(
ExposedEntity(
entity_id=state.entity_id,
state=state,
names=names,
area_names=area_names,
)
)
return exposed_entities
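Taken together, the loop above implements the usual tool-calling round trip: send the conversation, run any tool_calls the model returns, append each result as a role="tool" message, and ask again until the model answers in plain text. A minimal standalone sketch of the same protocol (the host, model, and multiply tool are hypothetical, and exact response types vary with the ollama client version):

    import asyncio
    import json

    import ollama

    TOOLS = [
        {
            "type": "function",
            "function": {
                "name": "multiply",  # hypothetical tool
                "description": "Multiply two integers",
                "parameters": {
                    "type": "object",
                    "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
                    "required": ["a", "b"],
                },
            },
        }
    ]

    async def main() -> None:
        client = ollama.AsyncClient(host="http://localhost:11434")  # placeholder host
        messages = [{"role": "user", "content": "What is 21 times 2?"}]
        for _ in range(10):  # cap iterations, as MAX_TOOL_ITERATIONS does above
            response = await client.chat(model="llama3.1", messages=messages, tools=TOOLS)
            message = response["message"]
            messages.append(message)
            if not message.get("tool_calls"):
                break  # plain-text answer; we are done
            for call in message["tool_calls"]:
                args = call["function"]["arguments"]
                if isinstance(args, str):  # some client versions return a JSON string
                    args = json.loads(args)
                # Feed the result back so the model can compose a final answer
                messages.append({"role": "tool", "content": json.dumps(args["a"] * args["b"])})
        print(messages[-1]["content"])

    asyncio.run(main())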

View File

@ -5,7 +5,9 @@ from unittest.mock import patch
import pytest
from homeassistant.components import ollama
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.setup import async_setup_component
from . import TEST_OPTIONS, TEST_USER_DATA
@ -25,6 +27,17 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
return entry
@pytest.fixture
def mock_config_entry_with_assist(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> MockConfigEntry:
"""Mock a config entry with assist."""
hass.config_entries.async_update_entry(
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
)
return mock_config_entry
@pytest.fixture
async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfigEntry):
"""Initialize integration."""

View File

@ -0,0 +1,34 @@
# serializer version: 1
# name: test_unknown_hass_api
dict({
'conversation_id': None,
'response': IntentResponse(
card=dict({
}),
error_code=<IntentResponseErrorCode.UNKNOWN: 'unknown'>,
failed_results=list([
]),
intent=None,
intent_targets=list([
]),
language='en',
matched_states=list([
]),
reprompt=dict({
}),
response_type=<IntentResponseType.ERROR: 'error'>,
speech=dict({
'plain': dict({
'extra_data': None,
'speech': 'Error preparing LLM API: API non-existing not found',
}),
}),
speech_slots=dict({
}),
success_results=list([
]),
unmatched_states=list([
]),
),
})
# ---

View File

@ -1,24 +1,31 @@
"""Tests for the Ollama integration."""
from unittest.mock import AsyncMock, Mock, patch
import logging
from ollama import Message, ResponseError
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation, ollama
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import ATTR_FRIENDLY_NAME, CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
intent,
llm,
)
from tests.common import MockConfigEntry
_LOGGER = logging.getLogger(__name__)
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
async def test_chat(
@ -124,6 +131,352 @@ async def test_chat(
detail_event = trace_events[1]
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
async def test_template_variables(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test that template variables work."""
context = Context(user_id="12345")
mock_user = Mock()
mock_user.id = "12345"
mock_user.name = "Test User"
hass.config_entries.async_update_entry(
mock_config_entry,
options={
"prompt": (
"The user name is {{ user_name }}. "
"The user id is {{ llm_context.context.user_id }}."
),
},
)
with (
patch("ollama.AsyncClient.list"),
patch(
"ollama.AsyncClient.chat",
return_value={"message": {"role": "assistant", "content": "test response"}},
) as mock_chat,
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
result = await conversation.async_converse(
hass, "hello", None, context, agent_id=mock_config_entry.entry_id
)
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result
args = mock_chat.call_args.kwargs
prompt = args["messages"][0]["content"]
assert "The user name is Test User." in prompt
assert "The user id is 12345." in prompt
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
async def test_function_call(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
) -> None:
"""Test function call from the assistant."""
agent_id = mock_config_entry_with_assist.entry_id
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{vol.Optional("param1", description="Test parameters"): str}
)
mock_tool.async_call.return_value = "Test response"
mock_get_tools.return_value = [mock_tool]
def completion_result(*args, messages, tools, **kwargs):
_LOGGER.debug("tools=%s", tools)
for message in messages:
if message["role"] == "tool":
return {
"message": {
"role": "assistant",
"content": "I have successfully called the function",
}
}
assert tools
return {
"message": {
"role": "assistant",
"content": "Calling tool",
"tool_calls": [{
"function": {
"name": "test_tool",
"arguments": '{"param1": "test_value"}'
}
}]
}
}
with patch(
"ollama.AsyncClient.chat",
side_effect=completion_result,
) as mock_chat:
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
)
assert mock_chat.call_count == 2
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.speech["plain"]["speech"]
== "I have successfully called the function"
)
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": "test_value"},
),
llm.LLMContext(
platform="ollama",
context=context,
user_prompt="Please call the test function",
language="en",
assistant="conversation",
device_id=None,
),
)
# Test Conversation tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
trace.ConversationTraceEventType.LLM_TOOL_CALL,
]
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
async def test_malformed_function_args(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
) -> None:
"""Test getting function args for an unknown function."""
agent_id = mock_config_entry_with_assist.entry_id
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{vol.Optional("param1", description="Test parameters"): str}
)
mock_tool.async_call.return_value = "Test response"
mock_get_tools.return_value = [mock_tool]
def completion_result(*args, messages, **kwargs):
for message in messages:
if message["content"].startswith("TOOL_ARGS"):
return {
"message": {
"role": "assistant",
"content": "I was not able to call the function",
}
}
return {
"message": {
"role": "assistant",
"content": "TOOL_ARGS unknown_tool",
}
}
with patch(
"ollama.AsyncClient.chat",
side_effect=completion_result,
) as mock_chat:
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
)
assert mock_tool.async_call.call_count == 0
assert mock_chat.call_count == 2
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.speech["plain"]["speech"]
== "I was not able to call the function"
)
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
async def test_malformed_function_call(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
) -> None:
"""Test function call that was unrecognized."""
agent_id = mock_config_entry_with_assist.entry_id
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{vol.Optional("param1", description="Test parameters"): str}
)
mock_tool.async_call.return_value = "Test response"
mock_get_tools.return_value = [mock_tool]
def completion_result(*args, messages, **kwargs):
for message in messages:
if message["content"].startswith("TOOL_CALL"):
return {
"message": {
"role": "assistant",
"content": "I was not able to call the function",
}
}
return {
"message": {
"role": "assistant",
"content": 'TOOL_CALL name="test_tool", param1="test_value"',
}
}
with patch(
"ollama.AsyncClient.chat",
side_effect=completion_result,
) as mock_chat:
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
)
assert mock_tool.async_call.call_count == 0
assert mock_chat.call_count == 2
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.speech["plain"]["speech"]
== "I was not able to call the function"
)
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
async def test_function_exception(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
) -> None:
"""Test function call with exception."""
agent_id = mock_config_entry_with_assist.entry_id
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{vol.Optional("param1", description="Test parameters"): str}
)
mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception")
mock_get_tools.return_value = [mock_tool]
def completion_result(*args, messages, **kwargs):
for message in messages:
if message["content"].startswith("TOOL_CALL"):
return {
"message": {
"role": "assistant",
"content": "There was an error calling the function",
}
}
return {
"message": {
"role": "assistant",
"content": 'TOOL_CALL {"name": "test_tool", "parameters": {"param1": "test_value"}}',
}
}
with patch(
"ollama.AsyncClient.chat",
side_effect=completion_result,
) as mock_chat:
result = await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
)
assert mock_chat.call_count == 2
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.speech["plain"]["speech"]
== "There was an error calling the function"
)
mock_tool.async_call.assert_awaited_once_with(
hass,
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": "test_value"},
),
llm.LLMContext(
platform="ollama",
context=context,
user_prompt="Please call the test function",
language="en",
assistant="conversation",
device_id=None,
),
)
async def test_unknown_hass_api(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
snapshot: SnapshotAssertion,
mock_init_component,
) -> None:
"""Test when we reference an API that no longer exists."""
hass.config_entries.async_update_entry(
mock_config_entry,
options={
**mock_config_entry.options,
CONF_LLM_HASS_API: "non-existing",
},
)
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result == snapshot
async def test_message_history_trimming(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component

View File

@ -128,9 +128,9 @@ asyncio.set_event_loop_policy(runner.HassEventLoopPolicy(False))
asyncio.set_event_loop_policy = lambda policy: None
#def pytest_addoption(parser: pytest.Parser) -> None:
# """Register custom pytest options."""
# parser.addoption("--dburl", action="store", default="sqlite://")
def pytest_configure(config: pytest.Config) -> None: