mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 17:57:11 +00:00
Improve quality of ollama tool calling by repairing arguments (#122749)
* Improve quality of ollama function calling by repairing function call arguments * Fix formatting of the tests * Run ruff format on ollama conversation * Add test for non-string arguments
This commit is contained in:
parent
8b96c7873f
commit
4f5eab4646
@ -63,6 +63,34 @@ def _format_tool(
|
|||||||
return {"type": "function", "function": tool_spec}
|
return {"type": "function", "function": tool_spec}
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_invalid_arguments(value: Any) -> Any:
|
||||||
|
"""Attempt to repair incorrectly formatted json function arguments.
|
||||||
|
|
||||||
|
Small models (for example llama3.1 8B) may produce invalid argument values
|
||||||
|
which we attempt to repair here.
|
||||||
|
"""
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return value
|
||||||
|
if (value.startswith("[") and value.endswith("]")) or (
|
||||||
|
value.startswith("{") and value.endswith("}")
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
return json.loads(value)
|
||||||
|
except json.decoder.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Rewrite ollama tool arguments.
|
||||||
|
|
||||||
|
This function improves tool use quality by fixing common mistakes made by
|
||||||
|
small local tool use models. This will repair invalid json arguments and
|
||||||
|
omit unnecessary arguments with empty values that will fail intent parsing.
|
||||||
|
"""
|
||||||
|
return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}
|
||||||
|
|
||||||
|
|
||||||
class OllamaConversationEntity(
|
class OllamaConversationEntity(
|
||||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
||||||
):
|
):
|
||||||
@ -255,7 +283,7 @@ class OllamaConversationEntity(
|
|||||||
for tool_call in tool_calls:
|
for tool_call in tool_calls:
|
||||||
tool_input = llm.ToolInput(
|
tool_input = llm.ToolInput(
|
||||||
tool_name=tool_call["function"]["name"],
|
tool_name=tool_call["function"]["name"],
|
||||||
tool_args=tool_call["function"]["arguments"],
|
tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
|
||||||
)
|
)
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Tests for the Ollama integration."""
|
"""Tests for the Ollama integration."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
from ollama import Message, ResponseError
|
from ollama import Message, ResponseError
|
||||||
@ -116,12 +117,33 @@ async def test_template_variables(
|
|||||||
assert "The user id is 12345." in prompt
|
assert "The user id is 12345." in prompt
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("tool_args", "expected_tool_args"),
|
||||||
|
[
|
||||||
|
({"param1": "test_value"}, {"param1": "test_value"}),
|
||||||
|
({"param1": 2}, {"param1": 2}),
|
||||||
|
(
|
||||||
|
{"param1": "test_value", "floor": ""},
|
||||||
|
{"param1": "test_value"}, # Omit empty arguments
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"domain": '["light"]'},
|
||||||
|
{"domain": ["light"]}, # Repair invalid json arguments
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"domain": "['light']"},
|
||||||
|
{"domain": "['light']"}, # Preserve invalid json that can't be parsed
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
|
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
|
||||||
async def test_function_call(
|
async def test_function_call(
|
||||||
mock_get_tools,
|
mock_get_tools,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry_with_assist: MockConfigEntry,
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
|
tool_args: dict[str, Any],
|
||||||
|
expected_tool_args: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test function call from the assistant."""
|
"""Test function call from the assistant."""
|
||||||
agent_id = mock_config_entry_with_assist.entry_id
|
agent_id = mock_config_entry_with_assist.entry_id
|
||||||
@ -154,7 +176,7 @@ async def test_function_call(
|
|||||||
{
|
{
|
||||||
"function": {
|
"function": {
|
||||||
"name": "test_tool",
|
"name": "test_tool",
|
||||||
"arguments": {"param1": "test_value"},
|
"arguments": tool_args,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -183,7 +205,7 @@ async def test_function_call(
|
|||||||
hass,
|
hass,
|
||||||
llm.ToolInput(
|
llm.ToolInput(
|
||||||
tool_name="test_tool",
|
tool_name="test_tool",
|
||||||
tool_args={"param1": "test_value"},
|
tool_args=expected_tool_args,
|
||||||
),
|
),
|
||||||
llm.LLMContext(
|
llm.LLMContext(
|
||||||
platform="ollama",
|
platform="ollama",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user