mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00
Update ollama tool calls
This commit is contained in:
parent
8f688ee079
commit
eaeca423d4
@ -14,31 +14,19 @@ from voluptuous_openapi import convert
|
|||||||
|
|
||||||
from homeassistant.components import assist_pipeline, conversation
|
from homeassistant.components import assist_pipeline, conversation
|
||||||
from homeassistant.components.conversation import trace
|
from homeassistant.components.conversation import trace
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.const import MATCH_ALL
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import intent, llm, template
|
||||||
area_registry as ar,
|
|
||||||
device_registry as dr,
|
|
||||||
entity_registry as er,
|
|
||||||
intent,
|
|
||||||
template,
|
|
||||||
llm,
|
|
||||||
)
|
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.util import ulid
|
from homeassistant.util import ulid
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
CONF_KEEP_ALIVE,
|
|
||||||
CONF_MAX_HISTORY,
|
CONF_MAX_HISTORY,
|
||||||
CONF_MODEL,
|
CONF_MODEL,
|
||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
DEFAULT_KEEP_ALIVE,
|
|
||||||
DEFAULT_MAX_HISTORY,
|
DEFAULT_MAX_HISTORY,
|
||||||
DEFAULT_PROMPT,
|
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
KEEP_ALIVE_FOREVER,
|
KEEP_ALIVE_FOREVER,
|
||||||
MAX_HISTORY_SECONDS,
|
MAX_HISTORY_SECONDS,
|
||||||
@ -50,6 +38,7 @@ MAX_TOOL_ITERATIONS = 10
|
|||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _format_tool(
|
def _format_tool(
|
||||||
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
@ -60,7 +49,7 @@ def _format_tool(
|
|||||||
}
|
}
|
||||||
if tool.description:
|
if tool.description:
|
||||||
tool_spec["description"] = tool.description
|
tool_spec["description"] = tool.description
|
||||||
return tool_spec
|
return {"type": "function", "function": tool_spec}
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(
|
async def async_setup_entry(
|
||||||
@ -147,10 +136,9 @@ class OllamaConversationEntity(
|
|||||||
return conversation.ConversationResult(
|
return conversation.ConversationResult(
|
||||||
response=intent_response, conversation_id=user_input.conversation_id
|
response=intent_response, conversation_id=user_input.conversation_id
|
||||||
)
|
)
|
||||||
tools = {
|
tools = [
|
||||||
tool.name: _format_tool(tool, llm_api.custom_serializer)
|
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
|
||||||
for tool in llm_api.tools
|
]
|
||||||
}
|
|
||||||
_LOGGER.debug("tools=%s", tools)
|
_LOGGER.debug("tools=%s", tools)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -200,7 +188,6 @@ class OllamaConversationEntity(
|
|||||||
else:
|
else:
|
||||||
_LOGGER.debug("no llm api prompt parts")
|
_LOGGER.debug("no llm api prompt parts")
|
||||||
|
|
||||||
|
|
||||||
prompt = "\n".join(prompt_parts)
|
prompt = "\n".join(prompt_parts)
|
||||||
_LOGGER.debug("Prompt: %s", prompt)
|
_LOGGER.debug("Prompt: %s", prompt)
|
||||||
|
|
||||||
@ -259,9 +246,7 @@ class OllamaConversationEntity(
|
|||||||
tool_calls = response_message.get("tool_calls")
|
tool_calls = response_message.get("tool_calls")
|
||||||
|
|
||||||
def message_convert(response_message: Any) -> ollama.Message:
|
def message_convert(response_message: Any) -> ollama.Message:
|
||||||
msg = ollama.Message(
|
msg = ollama.Message(role=response_message["role"])
|
||||||
role=response_message["role"]
|
|
||||||
)
|
|
||||||
if content := response_message.get("content"):
|
if content := response_message.get("content"):
|
||||||
msg["content"] = content
|
msg["content"] = content
|
||||||
if tool_calls := response_message.get("tool_calls"):
|
if tool_calls := response_message.get("tool_calls"):
|
||||||
@ -276,10 +261,11 @@ class OllamaConversationEntity(
|
|||||||
break
|
break
|
||||||
|
|
||||||
_LOGGER.debug("Response: %s", response_message.get("content"))
|
_LOGGER.debug("Response: %s", response_message.get("content"))
|
||||||
|
_LOGGER.debug("Tools calls: %s", tool_calls)
|
||||||
for tool_call in tool_calls:
|
for tool_call in tool_calls:
|
||||||
tool_input = llm.ToolInput(
|
tool_input = llm.ToolInput(
|
||||||
tool_name=tool_call["function"]["name"],
|
tool_name=tool_call["function"]["name"],
|
||||||
tool_args=json.loads(tool_call["function"]["arguments"]),
|
tool_args=tool_call["function"]["arguments"],
|
||||||
)
|
)
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||||
@ -294,9 +280,7 @@ class OllamaConversationEntity(
|
|||||||
|
|
||||||
_LOGGER.debug("Tool response: %s", tool_response)
|
_LOGGER.debug("Tool response: %s", tool_response)
|
||||||
message_history.messages.append(
|
message_history.messages.append(
|
||||||
ollama.Message(
|
ollama.Message(role="tool", content=json.dumps(tool_response))
|
||||||
role="tool", content=json.dumps(tool_response)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create intent response
|
# Create intent response
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Tests for the Ollama integration."""
|
"""Tests for the Ollama integration."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
|
||||||
import logging
|
import logging
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
from ollama import Message, ResponseError
|
from ollama import Message, ResponseError
|
||||||
import pytest
|
import pytest
|
||||||
@ -11,8 +11,7 @@ import voluptuous as vol
|
|||||||
from homeassistant.components import conversation, ollama
|
from homeassistant.components import conversation, ollama
|
||||||
from homeassistant.components.conversation import trace
|
from homeassistant.components.conversation import trace
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import ATTR_FRIENDLY_NAME, CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
|
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
@ -27,6 +26,7 @@ from tests.common import MockConfigEntry
|
|||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
|
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
|
||||||
async def test_chat(
|
async def test_chat(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
@ -131,6 +131,7 @@ async def test_chat(
|
|||||||
detail_event = trace_events[1]
|
detail_event = trace_events[1]
|
||||||
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
|
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
|
||||||
|
|
||||||
|
|
||||||
async def test_template_variables(
|
async def test_template_variables(
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -205,17 +206,19 @@ async def test_function_call(
|
|||||||
"content": "I have successfully called the function",
|
"content": "I have successfully called the function",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert tools
|
assert tools == {}
|
||||||
return {
|
return {
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": "Calling tool",
|
"content": "Calling tool",
|
||||||
"tool_calls": [{
|
"tool_calls": [
|
||||||
"function": {
|
{
|
||||||
"name": "test_tool",
|
"function": {
|
||||||
"arguments": '{"param1": "test_value"}'
|
"name": "test_tool",
|
||||||
|
"arguments": '{"param1": "test_value"}',
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}]
|
],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user