Improve quality of ollama tool calling by repairing arguments (#122749)

* Improve quality of ollama function calling by repairing function call arguments

* Fix formatting of the tests

* Run ruff format on ollama conversation

* Add test for non-string arguments
This commit is contained in:
Allen Porter 2024-07-31 05:37:39 -07:00 committed by GitHub
parent 8b96c7873f
commit 4f5eab4646
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 3 deletions

View File

@ -63,6 +63,34 @@ def _format_tool(
return {"type": "function", "function": tool_spec}
def _fix_invalid_arguments(value: Any) -> Any:
"""Attempt to repair incorrectly formatted json function arguments.
Small models (for example llama3.1 8B) may produce invalid argument values
which we attempt to repair here.
"""
if not isinstance(value, str):
return value
if (value.startswith("[") and value.endswith("]")) or (
value.startswith("{") and value.endswith("}")
):
try:
return json.loads(value)
except json.decoder.JSONDecodeError:
pass
return value
def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
"""Rewrite ollama tool arguments.
This function improves tool use quality by fixing common mistakes made by
small local tool use models. This will repair invalid json arguments and
omit unnecessary arguments with empty values that will fail intent parsing.
"""
return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}
class OllamaConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@ -255,7 +283,7 @@ class OllamaConversationEntity(
for tool_call in tool_calls:
tool_input = llm.ToolInput(
tool_name=tool_call["function"]["name"],
tool_args=tool_call["function"]["arguments"],
tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
)
_LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args

View File

@ -1,5 +1,6 @@
"""Tests for the Ollama integration."""
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
from ollama import Message, ResponseError
@ -116,12 +117,33 @@ async def test_template_variables(
assert "The user id is 12345." in prompt
@pytest.mark.parametrize(
("tool_args", "expected_tool_args"),
[
({"param1": "test_value"}, {"param1": "test_value"}),
({"param1": 2}, {"param1": 2}),
(
{"param1": "test_value", "floor": ""},
{"param1": "test_value"}, # Omit empty arguments
),
(
{"domain": '["light"]'},
{"domain": ["light"]}, # Repair invalid json arguments
),
(
{"domain": "['light']"},
{"domain": "['light']"}, # Preserve invalid json that can't be parsed
),
],
)
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
async def test_function_call(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_assist: MockConfigEntry,
mock_init_component,
tool_args: dict[str, Any],
expected_tool_args: dict[str, Any],
) -> None:
"""Test function call from the assistant."""
agent_id = mock_config_entry_with_assist.entry_id
@ -154,7 +176,7 @@ async def test_function_call(
{
"function": {
"name": "test_tool",
"arguments": {"param1": "test_value"},
"arguments": tool_args,
}
}
],
@ -183,7 +205,7 @@ async def test_function_call(
hass,
llm.ToolInput(
tool_name="test_tool",
tool_args={"param1": "test_value"},
tool_args=expected_tool_args,
),
llm.LLMContext(
platform="ollama",