From 4f5eab4646e0560188f2dca6a9daae0e48b11dac Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Wed, 31 Jul 2024 05:37:39 -0700 Subject: [PATCH] Improve quality of ollama tool calling by repairing arguments (#122749) * Improve quality of ollama function calling by repairing function call arguments * Fix formatting of the tests * Run ruff format on ollama conversation * Add test for non-string arguments --- .../components/ollama/conversation.py | 30 ++++++++++++++++++- tests/components/ollama/test_conversation.py | 26 ++++++++++++++-- 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py index ac367a5cf6a..f59e268394b 100644 --- a/homeassistant/components/ollama/conversation.py +++ b/homeassistant/components/ollama/conversation.py @@ -63,6 +63,34 @@ def _format_tool( return {"type": "function", "function": tool_spec} +def _fix_invalid_arguments(value: Any) -> Any: + """Attempt to repair incorrectly formatted json function arguments. + + Small models (for example llama3.1 8B) may produce invalid argument values + which we attempt to repair here. + """ + if not isinstance(value, str): + return value + if (value.startswith("[") and value.endswith("]")) or ( + value.startswith("{") and value.endswith("}") + ): + try: + return json.loads(value) + except json.decoder.JSONDecodeError: + pass + return value + + +def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]: + """Rewrite ollama tool arguments. + + This function improves tool use quality by fixing common mistakes made by + small local tool use models. This will repair invalid json arguments and + omit unnecessary arguments with empty values that will fail intent parsing. + """ + return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v} + + class OllamaConversationEntity( conversation.ConversationEntity, conversation.AbstractConversationAgent ): @@ -255,7 +283,7 @@ class OllamaConversationEntity( for tool_call in tool_calls: tool_input = llm.ToolInput( tool_name=tool_call["function"]["name"], - tool_args=tool_call["function"]["arguments"], + tool_args=_parse_tool_args(tool_call["function"]["arguments"]), ) _LOGGER.debug( "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py index 9be6f3b33a3..b5a94cc6f57 100644 --- a/tests/components/ollama/test_conversation.py +++ b/tests/components/ollama/test_conversation.py @@ -1,5 +1,6 @@ """Tests for the Ollama integration.""" +from typing import Any from unittest.mock import AsyncMock, Mock, patch from ollama import Message, ResponseError @@ -116,12 +117,33 @@ async def test_template_variables( assert "The user id is 12345." in prompt +@pytest.mark.parametrize( + ("tool_args", "expected_tool_args"), + [ + ({"param1": "test_value"}, {"param1": "test_value"}), + ({"param1": 2}, {"param1": 2}), + ( + {"param1": "test_value", "floor": ""}, + {"param1": "test_value"}, # Omit empty arguments + ), + ( + {"domain": '["light"]'}, + {"domain": ["light"]}, # Repair invalid json arguments + ), + ( + {"domain": "['light']"}, + {"domain": "['light']"}, # Preserve invalid json that can't be parsed + ), + ], +) @patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools") async def test_function_call( mock_get_tools, hass: HomeAssistant, mock_config_entry_with_assist: MockConfigEntry, mock_init_component, + tool_args: dict[str, Any], + expected_tool_args: dict[str, Any], ) -> None: """Test function call from the assistant.""" agent_id = mock_config_entry_with_assist.entry_id @@ -154,7 +176,7 @@ async def test_function_call( { "function": { "name": "test_tool", - "arguments": {"param1": "test_value"}, + "arguments": tool_args, } } ], @@ -183,7 +205,7 @@ async def test_function_call( hass, llm.ToolInput( tool_name="test_tool", - tool_args={"param1": "test_value"}, + tool_args=expected_tool_args, ), llm.LLMContext( platform="ollama",