Mirror of https://github.com/home-assistant/core.git (synced 2025-04-22 16:27:56 +00:00)
Improve quality of ollama tool calling by repairing arguments (#122749)
* Improve quality of ollama function calling by repairing function call arguments
* Fix formatting of the tests
* Run ruff format on ollama conversation
* Add test for non-string arguments
This commit is contained in:
parent 8b96c7873f
commit 4f5eab4646
@@ -63,6 +63,34 @@ def _format_tool(
     return {"type": "function", "function": tool_spec}


+def _fix_invalid_arguments(value: Any) -> Any:
+    """Attempt to repair incorrectly formatted json function arguments.
+
+    Small models (for example llama3.1 8B) may produce invalid argument values
+    which we attempt to repair here.
+    """
+    if not isinstance(value, str):
+        return value
+    if (value.startswith("[") and value.endswith("]")) or (
+        value.startswith("{") and value.endswith("}")
+    ):
+        try:
+            return json.loads(value)
+        except json.decoder.JSONDecodeError:
+            pass
+    return value
+
+
+def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
+    """Rewrite ollama tool arguments.
+
+    This function improves tool use quality by fixing common mistakes made by
+    small local tool use models. This will repair invalid json arguments and
+    omit unnecessary arguments with empty values that will fail intent parsing.
+    """
+    return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}
+
+
 class OllamaConversationEntity(
     conversation.ConversationEntity, conversation.AbstractConversationAgent
 ):
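The behavior of the new low-level helper can be sketched quickly; this is illustrative only, and the inputs mirror the cases exercised by the new tests:

# Illustrative: _fix_invalid_arguments only touches strings that look like JSON containers.
_fix_invalid_arguments(2)             # 2 (non-strings pass through)
_fix_invalid_arguments("test_value")  # "test_value" (plain strings untouched)
_fix_invalid_arguments('["light"]')   # ["light"] (valid JSON string parsed)
_fix_invalid_arguments("['light']")   # "['light']" (invalid JSON left as-is)

_parse_tool_args applies this per value and additionally drops keys with falsy values (for example an empty floor string), which would otherwise fail intent parsing.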
@@ -255,7 +283,7 @@ class OllamaConversationEntity(
             for tool_call in tool_calls:
                 tool_input = llm.ToolInput(
                     tool_name=tool_call["function"]["name"],
-                    tool_args=tool_call["function"]["arguments"],
+                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
                 )
                 _LOGGER.debug(
                     "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
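At the call site, the raw model arguments now pass through _parse_tool_args before reaching llm.ToolInput. A minimal sketch, assuming the tool-call shape used by the mocked response in the tests below:

from homeassistant.helpers import llm

# Hypothetical malformed response from a small model; the dict shape matches the test mock.
tool_call = {"function": {"name": "test_tool", "arguments": {"domain": '["light"]', "floor": ""}}}
tool_input = llm.ToolInput(
    tool_name=tool_call["function"]["name"],
    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
)
# tool_input.tool_args == {"domain": ["light"]}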
@@ -1,5 +1,6 @@
 """Tests for the Ollama integration."""

+from typing import Any
 from unittest.mock import AsyncMock, Mock, patch

 from ollama import Message, ResponseError
@@ -116,12 +117,33 @@ async def test_template_variables(
     assert "The user id is 12345." in prompt


+@pytest.mark.parametrize(
+    ("tool_args", "expected_tool_args"),
+    [
+        ({"param1": "test_value"}, {"param1": "test_value"}),
+        ({"param1": 2}, {"param1": 2}),
+        (
+            {"param1": "test_value", "floor": ""},
+            {"param1": "test_value"},  # Omit empty arguments
+        ),
+        (
+            {"domain": '["light"]'},
+            {"domain": ["light"]},  # Repair invalid json arguments
+        ),
+        (
+            {"domain": "['light']"},
+            {"domain": "['light']"},  # Preserve invalid json that can't be parsed
+        ),
+    ],
+)
 @patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
 async def test_function_call(
     mock_get_tools,
     hass: HomeAssistant,
     mock_config_entry_with_assist: MockConfigEntry,
     mock_init_component,
+    tool_args: dict[str, Any],
+    expected_tool_args: dict[str, Any],
 ) -> None:
     """Test function call from the assistant."""
     agent_id = mock_config_entry_with_assist.entry_id
@@ -154,7 +176,7 @@ async def test_function_call(
                     {
                         "function": {
                             "name": "test_tool",
-                            "arguments": {"param1": "test_value"},
+                            "arguments": tool_args,
                         }
                     }
                 ],
@@ -183,7 +205,7 @@ async def test_function_call(
         hass,
         llm.ToolInput(
             tool_name="test_tool",
-            tool_args={"param1": "test_value"},
+            tool_args=expected_tool_args,
         ),
         llm.LLMContext(
            platform="ollama",
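For reference, the (tool_args, expected_tool_args) pairs from the parametrize table can also be checked directly against the helper, without the conversation-agent fixtures; a minimal sketch assuming the helper is importable from homeassistant.components.ollama.conversation:

from homeassistant.components.ollama.conversation import _parse_tool_args

# Mirrors the parametrized cases added to test_function_call.
cases = [
    ({"param1": "test_value"}, {"param1": "test_value"}),
    ({"param1": 2}, {"param1": 2}),
    ({"param1": "test_value", "floor": ""}, {"param1": "test_value"}),  # empty value omitted
    ({"domain": '["light"]'}, {"domain": ["light"]}),  # JSON string repaired
    ({"domain": "['light']"}, {"domain": "['light']"}),  # unparseable JSON preserved
]
for tool_args, expected in cases:
    assert _parse_tool_args(tool_args) == expected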