mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Update ollama tool calls
This commit is contained in:
parent
8f688ee079
commit
eaeca423d4
@ -14,31 +14,19 @@ from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import assist_pipeline, conversation
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
intent,
|
||||
template,
|
||||
llm,
|
||||
)
|
||||
from homeassistant.helpers import intent, llm, template
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from .const import (
|
||||
CONF_KEEP_ALIVE,
|
||||
CONF_MAX_HISTORY,
|
||||
CONF_MODEL,
|
||||
CONF_PROMPT,
|
||||
DEFAULT_KEEP_ALIVE,
|
||||
DEFAULT_MAX_HISTORY,
|
||||
DEFAULT_PROMPT,
|
||||
DOMAIN,
|
||||
KEEP_ALIVE_FOREVER,
|
||||
MAX_HISTORY_SECONDS,
|
||||
@ -50,6 +38,7 @@ MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _format_tool(
|
||||
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||
) -> dict[str, Any]:
|
||||
@ -60,7 +49,7 @@ def _format_tool(
|
||||
}
|
||||
if tool.description:
|
||||
tool_spec["description"] = tool.description
|
||||
return tool_spec
|
||||
return {"type": "function", "function": tool_spec}
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
@ -147,10 +136,9 @@ class OllamaConversationEntity(
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
tools = {
|
||||
tool.name: _format_tool(tool, llm_api.custom_serializer)
|
||||
for tool in llm_api.tools
|
||||
}
|
||||
tools = [
|
||||
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
|
||||
]
|
||||
_LOGGER.debug("tools=%s", tools)
|
||||
|
||||
if (
|
||||
@ -200,7 +188,6 @@ class OllamaConversationEntity(
|
||||
else:
|
||||
_LOGGER.debug("no llm api prompt parts")
|
||||
|
||||
|
||||
prompt = "\n".join(prompt_parts)
|
||||
_LOGGER.debug("Prompt: %s", prompt)
|
||||
|
||||
@ -259,9 +246,7 @@ class OllamaConversationEntity(
|
||||
tool_calls = response_message.get("tool_calls")
|
||||
|
||||
def message_convert(response_message: Any) -> ollama.Message:
|
||||
msg = ollama.Message(
|
||||
role=response_message["role"]
|
||||
)
|
||||
msg = ollama.Message(role=response_message["role"])
|
||||
if content := response_message.get("content"):
|
||||
msg["content"] = content
|
||||
if tool_calls := response_message.get("tool_calls"):
|
||||
@ -276,10 +261,11 @@ class OllamaConversationEntity(
|
||||
break
|
||||
|
||||
_LOGGER.debug("Response: %s", response_message.get("content"))
|
||||
_LOGGER.debug("Tools calls: %s", tool_calls)
|
||||
for tool_call in tool_calls:
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name=tool_call["function"]["name"],
|
||||
tool_args=json.loads(tool_call["function"]["arguments"]),
|
||||
tool_args=tool_call["function"]["arguments"],
|
||||
)
|
||||
_LOGGER.debug(
|
||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||
@ -294,9 +280,7 @@ class OllamaConversationEntity(
|
||||
|
||||
_LOGGER.debug("Tool response: %s", tool_response)
|
||||
message_history.messages.append(
|
||||
ollama.Message(
|
||||
role="tool", content=json.dumps(tool_response)
|
||||
)
|
||||
ollama.Message(role="tool", content=json.dumps(tool_response))
|
||||
)
|
||||
|
||||
# Create intent response
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Tests for the Ollama integration."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from ollama import Message, ResponseError
|
||||
import pytest
|
||||
@ -11,8 +11,7 @@ import voluptuous as vol
|
||||
from homeassistant.components import conversation, ollama
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
|
||||
from homeassistant.const import ATTR_FRIENDLY_NAME, CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import (
|
||||
@ -27,6 +26,7 @@ from tests.common import MockConfigEntry
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
|
||||
async def test_chat(
|
||||
hass: HomeAssistant,
|
||||
@ -131,6 +131,7 @@ async def test_chat(
|
||||
detail_event = trace_events[1]
|
||||
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
|
||||
|
||||
|
||||
async def test_template_variables(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> None:
|
||||
@ -205,17 +206,19 @@ async def test_function_call(
|
||||
"content": "I have successfully called the function",
|
||||
}
|
||||
}
|
||||
assert tools
|
||||
assert tools == {}
|
||||
return {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Calling tool",
|
||||
"tool_calls": [{
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": '{"param1": "test_value"}'
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": '{"param1": "test_value"}',
|
||||
}
|
||||
}
|
||||
}]
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user