From eaeca423d4760098fb6e75a58f728de39dc2606a Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 20 Jul 2024 04:21:44 +0000 Subject: [PATCH] Update ollama tool calls --- .../components/ollama/conversation.py | 36 ++++++------------- tests/components/ollama/test_conversation.py | 21 ++++++----- 2 files changed, 22 insertions(+), 35 deletions(-) diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py index 797adfc9f41..f342833820c 100644 --- a/homeassistant/components/ollama/conversation.py +++ b/homeassistant/components/ollama/conversation.py @@ -14,31 +14,19 @@ from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation from homeassistant.components.conversation import trace -from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL -from homeassistant.const import MATCH_ALL from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError, TemplateError -from homeassistant.helpers import ( - area_registry as ar, - device_registry as dr, - entity_registry as er, - intent, - template, - llm, -) +from homeassistant.helpers import intent, llm, template from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util import ulid from .const import ( - CONF_KEEP_ALIVE, CONF_MAX_HISTORY, CONF_MODEL, CONF_PROMPT, - DEFAULT_KEEP_ALIVE, DEFAULT_MAX_HISTORY, - DEFAULT_PROMPT, DOMAIN, KEEP_ALIVE_FOREVER, MAX_HISTORY_SECONDS, @@ -50,6 +38,7 @@ MAX_TOOL_ITERATIONS = 10 _LOGGER = logging.getLogger(__name__) + def _format_tool( tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None ) -> dict[str, Any]: @@ -60,7 +49,7 @@ def _format_tool( } if tool.description: tool_spec["description"] = tool.description - return tool_spec + return {"type": "function", "function": tool_spec} async def async_setup_entry( @@ -147,10 +136,9 @@ class OllamaConversationEntity( return conversation.ConversationResult( response=intent_response, conversation_id=user_input.conversation_id ) - tools = { - tool.name: _format_tool(tool, llm_api.custom_serializer) - for tool in llm_api.tools - } + tools = [ + _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools + ] _LOGGER.debug("tools=%s", tools) if ( @@ -200,7 +188,6 @@ class OllamaConversationEntity( else: _LOGGER.debug("no llm api prompt parts") - prompt = "\n".join(prompt_parts) _LOGGER.debug("Prompt: %s", prompt) @@ -259,9 +246,7 @@ class OllamaConversationEntity( tool_calls = response_message.get("tool_calls") def message_convert(response_message: Any) -> ollama.Message: - msg = ollama.Message( - role=response_message["role"] - ) + msg = ollama.Message(role=response_message["role"]) if content := response_message.get("content"): msg["content"] = content if tool_calls := response_message.get("tool_calls"): @@ -276,10 +261,11 @@ class OllamaConversationEntity( break _LOGGER.debug("Response: %s", response_message.get("content")) + _LOGGER.debug("Tools calls: %s", tool_calls) for tool_call in tool_calls: tool_input = llm.ToolInput( tool_name=tool_call["function"]["name"], - tool_args=json.loads(tool_call["function"]["arguments"]), + tool_args=tool_call["function"]["arguments"], ) _LOGGER.debug( "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args @@ -294,9 +280,7 @@ class OllamaConversationEntity( _LOGGER.debug("Tool response: %s", tool_response) message_history.messages.append( - ollama.Message( - role="tool", content=json.dumps(tool_response) - ) + ollama.Message(role="tool", content=json.dumps(tool_response)) ) # Create intent response diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py index 1161f46cec5..3a070a8360a 100644 --- a/tests/components/ollama/test_conversation.py +++ b/tests/components/ollama/test_conversation.py @@ -1,7 +1,7 @@ """Tests for the Ollama integration.""" -from unittest.mock import AsyncMock, Mock, patch import logging +from unittest.mock import AsyncMock, Mock, patch from ollama import Message, ResponseError import pytest @@ -11,8 +11,7 @@ import voluptuous as vol from homeassistant.components import conversation, ollama from homeassistant.components.conversation import trace from homeassistant.components.homeassistant.exposed_entities import async_expose_entity -from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL -from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL +from homeassistant.const import ATTR_FRIENDLY_NAME, CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import ( @@ -27,6 +26,7 @@ from tests.common import MockConfigEntry _LOGGER = logging.getLogger(__name__) + @pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"]) async def test_chat( hass: HomeAssistant, @@ -131,6 +131,7 @@ async def test_chat( detail_event = trace_events[1] assert "The current time is" in detail_event["data"]["messages"][0]["content"] + async def test_template_variables( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> None: @@ -205,17 +206,19 @@ async def test_function_call( "content": "I have successfully called the function", } } - assert tools + assert tools == {} return { "message": { "role": "assistant", "content": "Calling tool", - "tool_calls": [{ - "function": { - "name": "test_tool", - "arguments": '{"param1": "test_value"}' + "tool_calls": [ + { + "function": { + "name": "test_tool", + "arguments": '{"param1": "test_value"}', + } } - }] + ], } }