Update ollama tool calls

This commit is contained in:
Allen Porter 2024-07-20 04:21:44 +00:00
parent 8f688ee079
commit eaeca423d4
2 changed files with 22 additions and 35 deletions

View File

@ -14,31 +14,19 @@ from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, TemplateError from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import ( from homeassistant.helpers import intent, llm, template
area_registry as ar,
device_registry as dr,
entity_registry as er,
intent,
template,
llm,
)
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid from homeassistant.util import ulid
from .const import ( from .const import (
CONF_KEEP_ALIVE,
CONF_MAX_HISTORY, CONF_MAX_HISTORY,
CONF_MODEL, CONF_MODEL,
CONF_PROMPT, CONF_PROMPT,
DEFAULT_KEEP_ALIVE,
DEFAULT_MAX_HISTORY, DEFAULT_MAX_HISTORY,
DEFAULT_PROMPT,
DOMAIN, DOMAIN,
KEEP_ALIVE_FOREVER, KEEP_ALIVE_FOREVER,
MAX_HISTORY_SECONDS, MAX_HISTORY_SECONDS,
@ -50,6 +38,7 @@ MAX_TOOL_ITERATIONS = 10
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def _format_tool( def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]: ) -> dict[str, Any]:
@ -60,7 +49,7 @@ def _format_tool(
} }
if tool.description: if tool.description:
tool_spec["description"] = tool.description tool_spec["description"] = tool.description
return tool_spec return {"type": "function", "function": tool_spec}
async def async_setup_entry( async def async_setup_entry(
@ -147,10 +136,9 @@ class OllamaConversationEntity(
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id response=intent_response, conversation_id=user_input.conversation_id
) )
tools = { tools = [
tool.name: _format_tool(tool, llm_api.custom_serializer) _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
for tool in llm_api.tools ]
}
_LOGGER.debug("tools=%s", tools) _LOGGER.debug("tools=%s", tools)
if ( if (
@ -200,7 +188,6 @@ class OllamaConversationEntity(
else: else:
_LOGGER.debug("no llm api prompt parts") _LOGGER.debug("no llm api prompt parts")
prompt = "\n".join(prompt_parts) prompt = "\n".join(prompt_parts)
_LOGGER.debug("Prompt: %s", prompt) _LOGGER.debug("Prompt: %s", prompt)
@ -259,9 +246,7 @@ class OllamaConversationEntity(
tool_calls = response_message.get("tool_calls") tool_calls = response_message.get("tool_calls")
def message_convert(response_message: Any) -> ollama.Message: def message_convert(response_message: Any) -> ollama.Message:
msg = ollama.Message( msg = ollama.Message(role=response_message["role"])
role=response_message["role"]
)
if content := response_message.get("content"): if content := response_message.get("content"):
msg["content"] = content msg["content"] = content
if tool_calls := response_message.get("tool_calls"): if tool_calls := response_message.get("tool_calls"):
@ -276,10 +261,11 @@ class OllamaConversationEntity(
break break
_LOGGER.debug("Response: %s", response_message.get("content")) _LOGGER.debug("Response: %s", response_message.get("content"))
_LOGGER.debug("Tools calls: %s", tool_calls)
for tool_call in tool_calls: for tool_call in tool_calls:
tool_input = llm.ToolInput( tool_input = llm.ToolInput(
tool_name=tool_call["function"]["name"], tool_name=tool_call["function"]["name"],
tool_args=json.loads(tool_call["function"]["arguments"]), tool_args=tool_call["function"]["arguments"],
) )
_LOGGER.debug( _LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
@ -294,9 +280,7 @@ class OllamaConversationEntity(
_LOGGER.debug("Tool response: %s", tool_response) _LOGGER.debug("Tool response: %s", tool_response)
message_history.messages.append( message_history.messages.append(
ollama.Message( ollama.Message(role="tool", content=json.dumps(tool_response))
role="tool", content=json.dumps(tool_response)
)
) )
# Create intent response # Create intent response

View File

@ -1,7 +1,7 @@
"""Tests for the Ollama integration.""" """Tests for the Ollama integration."""
from unittest.mock import AsyncMock, Mock, patch
import logging import logging
from unittest.mock import AsyncMock, Mock, patch
from ollama import Message, ResponseError from ollama import Message, ResponseError
import pytest import pytest
@ -11,8 +11,7 @@ import voluptuous as vol
from homeassistant.components import conversation, ollama from homeassistant.components import conversation, ollama
from homeassistant.components.conversation import trace from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import ATTR_FRIENDLY_NAME, CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import ( from homeassistant.helpers import (
@ -27,6 +26,7 @@ from tests.common import MockConfigEntry
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"]) @pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
async def test_chat( async def test_chat(
hass: HomeAssistant, hass: HomeAssistant,
@ -131,6 +131,7 @@ async def test_chat(
detail_event = trace_events[1] detail_event = trace_events[1]
assert "The current time is" in detail_event["data"]["messages"][0]["content"] assert "The current time is" in detail_event["data"]["messages"][0]["content"]
async def test_template_variables( async def test_template_variables(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None: ) -> None:
@ -205,17 +206,19 @@ async def test_function_call(
"content": "I have successfully called the function", "content": "I have successfully called the function",
} }
} }
assert tools assert tools == {}
return { return {
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": "Calling tool", "content": "Calling tool",
"tool_calls": [{ "tool_calls": [
"function": { {
"name": "test_tool", "function": {
"arguments": '{"param1": "test_value"}' "name": "test_tool",
"arguments": '{"param1": "test_value"}',
}
} }
}] ],
} }
} }