Update ollama tool calls

Allen Porter 2024-07-20 04:21:44 +00:00
parent 8f688ee079
commit eaeca423d4
2 changed files with 22 additions and 35 deletions


@@ -14,31 +14,19 @@ from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
intent,
template,
llm,
)
from homeassistant.helpers import intent, llm, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid
from .const import (
CONF_KEEP_ALIVE,
CONF_MAX_HISTORY,
CONF_MODEL,
CONF_PROMPT,
DEFAULT_KEEP_ALIVE,
DEFAULT_MAX_HISTORY,
DEFAULT_PROMPT,
DOMAIN,
KEEP_ALIVE_FOREVER,
MAX_HISTORY_SECONDS,
@@ -50,6 +38,7 @@ MAX_TOOL_ITERATIONS = 10
_LOGGER = logging.getLogger(__name__)
def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
@@ -60,7 +49,7 @@ def _format_tool(
}
if tool.description:
tool_spec["description"] = tool.description
return tool_spec
return {"type": "function", "function": tool_spec}
async def async_setup_entry(
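Note (not part of the commit): _format_tool now wraps the converted schema in the OpenAI-style "function" envelope used by Ollama's tool-calling interface. A minimal sketch of the resulting structure; the tool name and parameter schema below are made-up examples, not values from this commit:

# Illustrative sketch only.
formatted_tool = {
    "type": "function",
    "function": {
        "name": "HassTurnOn",                # tool.name (hypothetical)
        "description": "Turns on a device",  # tool.description, if set
        "parameters": {                      # converted voluptuous schema
            "type": "object",
            "properties": {"name": {"type": "string"}},
        },
    },
}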
@@ -147,10 +136,9 @@ class OllamaConversationEntity(
return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
tools = {
tool.name: _format_tool(tool, llm_api.custom_serializer)
for tool in llm_api.tools
}
tools = [
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
]
_LOGGER.debug("tools=%s", tools)
if (
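Note (not part of the commit): the formatted tools are now collected into a flat list rather than a dict keyed by tool name, which is the shape handed to the Ollama client's tools parameter. A hedged sketch, assuming an ollama-python release with tool-calling support; host, model, and message content are placeholders:

import ollama

async def chat_with_tools(tools: list[dict]) -> dict:
    # Placeholder host/model; the integration reads these from its config entry.
    client = ollama.AsyncClient(host="http://localhost:11434")
    return await client.chat(
        model="llama3.1",
        messages=[{"role": "user", "content": "Turn on the kitchen light"}],
        tools=tools,  # the flat list built with _format_tool above
    )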
@@ -200,7 +188,6 @@ class OllamaConversationEntity(
else:
_LOGGER.debug("no llm api prompt parts")
prompt = "\n".join(prompt_parts)
_LOGGER.debug("Prompt: %s", prompt)
@@ -259,9 +246,7 @@ class OllamaConversationEntity(
tool_calls = response_message.get("tool_calls")
def message_convert(response_message: Any) -> ollama.Message:
msg = ollama.Message(
role=response_message["role"]
)
msg = ollama.Message(role=response_message["role"])
if content := response_message.get("content"):
msg["content"] = content
if tool_calls := response_message.get("tool_calls"):
@@ -276,10 +261,11 @@ class OllamaConversationEntity(
break
_LOGGER.debug("Response: %s", response_message.get("content"))
_LOGGER.debug("Tools calls: %s", tool_calls)
for tool_call in tool_calls:
tool_input = llm.ToolInput(
tool_name=tool_call["function"]["name"],
tool_args=json.loads(tool_call["function"]["arguments"]),
tool_args=tool_call["function"]["arguments"],
)
_LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
@@ -294,9 +280,7 @@ class OllamaConversationEntity(
_LOGGER.debug("Tool response: %s", tool_response)
message_history.messages.append(
ollama.Message(
role="tool", content=json.dumps(tool_response)
)
ollama.Message(role="tool", content=json.dumps(tool_response))
)
# Create intent response
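Note (not part of the commit): after a tool runs, its result is serialized and appended to the running history as a "tool"-role message so the next chat call can see it. A minimal sketch, with the integration's message_history.messages stood in by a plain list:

import json

import ollama

messages: list[ollama.Message] = []  # stand-in for message_history.messages
tool_response = {"success": True}    # placeholder tool result
messages.append(ollama.Message(role="tool", content=json.dumps(tool_response)))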


@@ -1,7 +1,7 @@
"""Tests for the Ollama integration."""
from unittest.mock import AsyncMock, Mock, patch
import logging
from unittest.mock import AsyncMock, Mock, patch
from ollama import Message, ResponseError
import pytest
@@ -11,8 +11,7 @@ import voluptuous as vol
from homeassistant.components import conversation, ollama
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
from homeassistant.const import ATTR_FRIENDLY_NAME, CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
@@ -27,6 +26,7 @@ from tests.common import MockConfigEntry
_LOGGER = logging.getLogger(__name__)
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
async def test_chat(
hass: HomeAssistant,
@@ -131,6 +131,7 @@ async def test_chat(
detail_event = trace_events[1]
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
async def test_template_variables(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
@@ -205,17 +206,19 @@ async def test_function_call(
"content": "I have successfully called the function",
}
}
assert tools
assert tools == {}
return {
"message": {
"role": "assistant",
"content": "Calling tool",
"tool_calls": [{
"function": {
"name": "test_tool",
"arguments": '{"param1": "test_value"}'
"tool_calls": [
{
"function": {
"name": "test_tool",
"arguments": '{"param1": "test_value"}',
}
}
}]
],
}
}
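Note (not part of the commit): the test's mocked chat response now nests the assistant's tool call as shown above, and the mock asserts that a non-empty tools list was passed rather than comparing against an empty dict. A hedged, standalone sketch of such a mock side effect; the function name and signature are illustrative, not copied from the test module:

def fake_chat(messages: list, tools: list, **kwargs) -> dict:
    assert tools  # a non-empty tool list must have been passed
    return {
        "message": {
            "role": "assistant",
            "content": "Calling tool",
            "tool_calls": [
                {
                    "function": {
                        "name": "test_tool",
                        "arguments": '{"param1": "test_value"}',
                    }
                }
            ],
        }
    }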