mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 06:07:17 +00:00
Ollama tool calling
This commit is contained in:
parent
975cfa6457
commit
8f688ee079
@ -2,25 +2,31 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import ollama
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import assist_pipeline, conversation
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
intent,
|
||||
template,
|
||||
llm,
|
||||
)
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util import ulid
|
||||
@ -34,12 +40,28 @@ from .const import (
|
||||
DEFAULT_MAX_HISTORY,
|
||||
DEFAULT_PROMPT,
|
||||
DOMAIN,
|
||||
KEEP_ALIVE_FOREVER,
|
||||
MAX_HISTORY_SECONDS,
|
||||
)
|
||||
from .models import ExposedEntity, MessageHistory, MessageRole
|
||||
from .models import MessageHistory, MessageRole
|
||||
|
||||
# Max number of back and forth with the LLM to generate a response
|
||||
MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
def _format_tool(
|
||||
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||
) -> dict[str, Any]:
|
||||
"""Format tool specification."""
|
||||
tool_spec = {
|
||||
"name": tool.name,
|
||||
"parameters": convert(tool.parameters, custom_serializer=custom_serializer),
|
||||
}
|
||||
if tool.description:
|
||||
tool_spec["description"] = tool.description
|
||||
return tool_spec
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
@ -90,10 +112,55 @@ class OllamaConversationEntity(
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
settings = {**self.entry.data, **self.entry.options}
|
||||
options = self.entry.options
|
||||
|
||||
client = self.hass.data[DOMAIN][self.entry.entry_id]
|
||||
conversation_id = user_input.conversation_id or ulid.ulid_now()
|
||||
model = settings[CONF_MODEL]
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
llm_api: llm.APIInstance | None = None
|
||||
tools: dict[str, dict[str, Any]] | None = None
|
||||
user_name: str | None = None
|
||||
llm_context = llm.LLMContext(
|
||||
platform=DOMAIN,
|
||||
context=user_input.context,
|
||||
user_prompt=user_input.text,
|
||||
language=user_input.language,
|
||||
assistant=conversation.DOMAIN,
|
||||
device_id=user_input.device_id,
|
||||
)
|
||||
|
||||
_LOGGER.debug("CONF_LLM_HASS_API=%s", settings.get(CONF_LLM_HASS_API))
|
||||
if settings.get(CONF_LLM_HASS_API):
|
||||
try:
|
||||
llm_api = await llm.async_get_api(
|
||||
self.hass,
|
||||
settings[CONF_LLM_HASS_API],
|
||||
llm_context,
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error("Error getting LLM API: %s", err)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Error preparing LLM API: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=user_input.conversation_id
|
||||
)
|
||||
tools = {
|
||||
tool.name: _format_tool(tool, llm_api.custom_serializer)
|
||||
for tool in llm_api.tools
|
||||
}
|
||||
_LOGGER.debug("tools=%s", tools)
|
||||
|
||||
if (
|
||||
user_input.context
|
||||
and user_input.context.user_id
|
||||
and (
|
||||
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
||||
)
|
||||
):
|
||||
user_name = user.name
|
||||
|
||||
# Look up message history
|
||||
message_history: MessageHistory | None = None
|
||||
@ -102,13 +169,23 @@ class OllamaConversationEntity(
|
||||
# New history
|
||||
#
|
||||
# Render prompt and error out early if there's a problem
|
||||
raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
|
||||
try:
|
||||
prompt = self._generate_prompt(raw_prompt)
|
||||
_LOGGER.debug("Prompt: %s", prompt)
|
||||
prompt_parts = [
|
||||
template.Template(
|
||||
llm.BASE_PROMPT
|
||||
+ settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
||||
self.hass,
|
||||
).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
"user_name": user_name,
|
||||
"llm_context": llm_context,
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
]
|
||||
except TemplateError as err:
|
||||
_LOGGER.error("Error rendering prompt: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem generating my prompt: {err}",
|
||||
@ -117,6 +194,16 @@ class OllamaConversationEntity(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
if llm_api:
|
||||
_LOGGER.debug("llm api prompt parts")
|
||||
prompt_parts.append(llm_api.api_prompt)
|
||||
else:
|
||||
_LOGGER.debug("no llm api prompt parts")
|
||||
|
||||
|
||||
prompt = "\n".join(prompt_parts)
|
||||
_LOGGER.debug("Prompt: %s", prompt)
|
||||
|
||||
message_history = MessageHistory(
|
||||
timestamp=time.monotonic(),
|
||||
messages=[
|
||||
@ -146,32 +233,71 @@ class OllamaConversationEntity(
|
||||
)
|
||||
|
||||
# Get response
|
||||
try:
|
||||
response = await client.chat(
|
||||
model=model,
|
||||
# Make a copy of the messages because we mutate the list later
|
||||
messages=list(message_history.messages),
|
||||
stream=False,
|
||||
# keep_alive requires specifying unit. In this case, seconds
|
||||
keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
|
||||
)
|
||||
except (ollama.RequestError, ollama.ResponseError) as err:
|
||||
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem talking to the Ollama server: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
# To prevent infinite loops, we limit the number of iterations
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
try:
|
||||
response = await client.chat(
|
||||
model=model,
|
||||
# Make a copy of the messages because we mutate the list later
|
||||
messages=list(message_history.messages),
|
||||
stream=False,
|
||||
tools=tools,
|
||||
keep_alive=KEEP_ALIVE_FOREVER,
|
||||
)
|
||||
except (ollama.RequestError, ollama.ResponseError) as err:
|
||||
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem talking to the Ollama server: {err}",
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
response_message = response["message"]
|
||||
message_history.messages.append(
|
||||
ollama.Message(
|
||||
role=response_message["role"], content=response_message["content"]
|
||||
)
|
||||
)
|
||||
_LOGGER.debug("Response: %s", response)
|
||||
response_message = response["message"]
|
||||
tool_calls = response_message.get("tool_calls")
|
||||
|
||||
def message_convert(response_message: Any) -> ollama.Message:
|
||||
msg = ollama.Message(
|
||||
role=response_message["role"]
|
||||
)
|
||||
if content := response_message.get("content"):
|
||||
msg["content"] = content
|
||||
if tool_calls := response_message.get("tool_calls"):
|
||||
msg["tool_calls"] = tool_calls
|
||||
return msg
|
||||
|
||||
message_history.messages.append(message_convert(response_message))
|
||||
|
||||
if not tool_calls or not llm_api:
|
||||
_LOGGER.debug("tool_calls=%s", tool_calls)
|
||||
_LOGGER.debug("llm_api=%s", llm_api)
|
||||
break
|
||||
|
||||
_LOGGER.debug("Response: %s", response_message.get("content"))
|
||||
for tool_call in tool_calls:
|
||||
tool_input = llm.ToolInput(
|
||||
tool_name=tool_call["function"]["name"],
|
||||
tool_args=json.loads(tool_call["function"]["arguments"]),
|
||||
)
|
||||
_LOGGER.debug(
|
||||
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||
)
|
||||
|
||||
try:
|
||||
tool_response = await llm_api.async_call_tool(tool_input)
|
||||
except (HomeAssistantError, vol.Invalid) as e:
|
||||
tool_response = {"error": type(e).__name__}
|
||||
if str(e):
|
||||
tool_response["error_text"] = str(e)
|
||||
|
||||
_LOGGER.debug("Tool response: %s", tool_response)
|
||||
message_history.messages.append(
|
||||
ollama.Message(
|
||||
role="tool", content=json.dumps(tool_response)
|
||||
)
|
||||
)
|
||||
|
||||
# Create intent response
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
@ -204,62 +330,3 @@ class OllamaConversationEntity(
|
||||
message_history.messages = [
|
||||
message_history.messages[0]
|
||||
] + message_history.messages[drop_index:]
|
||||
|
||||
def _generate_prompt(self, raw_prompt: str) -> str:
|
||||
"""Generate a prompt for the user."""
|
||||
return template.Template(raw_prompt, self.hass).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
"ha_language": self.hass.config.language,
|
||||
"exposed_entities": self._get_exposed_entities(),
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
|
||||
def _get_exposed_entities(self) -> list[ExposedEntity]:
|
||||
"""Get state list of exposed entities."""
|
||||
area_registry = ar.async_get(self.hass)
|
||||
entity_registry = er.async_get(self.hass)
|
||||
device_registry = dr.async_get(self.hass)
|
||||
|
||||
exposed_entities = []
|
||||
exposed_states = [
|
||||
state
|
||||
for state in self.hass.states.async_all()
|
||||
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
|
||||
]
|
||||
|
||||
for state in exposed_states:
|
||||
entity_entry = entity_registry.async_get(state.entity_id)
|
||||
names = [state.name]
|
||||
area_names = []
|
||||
|
||||
if entity_entry is not None:
|
||||
# Add aliases
|
||||
names.extend(entity_entry.aliases)
|
||||
if entity_entry.area_id and (
|
||||
area := area_registry.async_get_area(entity_entry.area_id)
|
||||
):
|
||||
# Entity is in area
|
||||
area_names.append(area.name)
|
||||
area_names.extend(area.aliases)
|
||||
elif entity_entry.device_id and (
|
||||
device := device_registry.async_get(entity_entry.device_id)
|
||||
):
|
||||
# Check device area
|
||||
if device.area_id and (
|
||||
area := area_registry.async_get_area(device.area_id)
|
||||
):
|
||||
area_names.append(area.name)
|
||||
area_names.extend(area.aliases)
|
||||
|
||||
exposed_entities.append(
|
||||
ExposedEntity(
|
||||
entity_id=state.entity_id,
|
||||
state=state,
|
||||
names=names,
|
||||
area_names=area_names,
|
||||
)
|
||||
)
|
||||
|
||||
return exposed_entities
|
||||
|
@ -5,7 +5,9 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import ollama
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import TEST_OPTIONS, TEST_USER_DATA
|
||||
@ -25,6 +27,17 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry_with_assist(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> MockConfigEntry:
|
||||
"""Mock a config entry with assist."""
|
||||
hass.config_entries.async_update_entry(
|
||||
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
|
||||
)
|
||||
return mock_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfigEntry):
|
||||
"""Initialize integration."""
|
||||
|
34
tests/components/ollama/snapshots/test_conversation.ambr
Normal file
34
tests/components/ollama/snapshots/test_conversation.ambr
Normal file
@ -0,0 +1,34 @@
|
||||
# serializer version: 1
|
||||
# name: test_unknown_hass_api
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'response': IntentResponse(
|
||||
card=dict({
|
||||
}),
|
||||
error_code=<IntentResponseErrorCode.UNKNOWN: 'unknown'>,
|
||||
failed_results=list([
|
||||
]),
|
||||
intent=None,
|
||||
intent_targets=list([
|
||||
]),
|
||||
language='en',
|
||||
matched_states=list([
|
||||
]),
|
||||
reprompt=dict({
|
||||
}),
|
||||
response_type=<IntentResponseType.ERROR: 'error'>,
|
||||
speech=dict({
|
||||
'plain': dict({
|
||||
'extra_data': None,
|
||||
'speech': 'Error preparing LLM API: API non-existing not found',
|
||||
}),
|
||||
}),
|
||||
speech_slots=dict({
|
||||
}),
|
||||
success_results=list([
|
||||
]),
|
||||
unmatched_states=list([
|
||||
]),
|
||||
),
|
||||
})
|
||||
# ---
|
@ -1,24 +1,31 @@
|
||||
"""Tests for the Ollama integration."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
import logging
|
||||
|
||||
from ollama import Message, ResponseError
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import conversation, ollama
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
intent,
|
||||
llm,
|
||||
)
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
|
||||
async def test_chat(
|
||||
@ -124,6 +131,352 @@ async def test_chat(
|
||||
detail_event = trace_events[1]
|
||||
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
|
||||
|
||||
async def test_template_variables(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> None:
|
||||
"""Test that template variables work."""
|
||||
context = Context(user_id="12345")
|
||||
mock_user = Mock()
|
||||
mock_user.id = "12345"
|
||||
mock_user.name = "Test User"
|
||||
|
||||
hass.config_entries.async_update_entry(
|
||||
mock_config_entry,
|
||||
options={
|
||||
"prompt": (
|
||||
"The user name is {{ user_name }}. "
|
||||
"The user id is {{ llm_context.context.user_id }}."
|
||||
),
|
||||
},
|
||||
)
|
||||
with (
|
||||
patch("ollama.AsyncClient.list"),
|
||||
patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
return_value={"message": {"role": "assistant", "content": "test response"}},
|
||||
) as mock_chat,
|
||||
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
|
||||
):
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, context, agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert (
|
||||
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
), result
|
||||
|
||||
args = mock_chat.call_args.kwargs
|
||||
prompt = args["messages"][0]["content"]
|
||||
|
||||
assert "The user name is Test User." in prompt
|
||||
assert "The user id is 12345." in prompt
|
||||
|
||||
|
||||
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
|
||||
async def test_function_call(
|
||||
mock_get_tools,
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_assist: MockConfigEntry,
|
||||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test function call from the assistant."""
|
||||
agent_id = mock_config_entry_with_assist.entry_id
|
||||
context = Context()
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.name = "test_tool"
|
||||
mock_tool.description = "Test function"
|
||||
mock_tool.parameters = vol.Schema(
|
||||
{vol.Optional("param1", description="Test parameters"): str}
|
||||
)
|
||||
mock_tool.async_call.return_value = "Test response"
|
||||
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
def completion_result(*args, messages, tools, **kwargs):
|
||||
_LOGGER.debug("tools=%s", tools)
|
||||
for message in messages:
|
||||
if message["role"] == "tool":
|
||||
return {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I have successfully called the function",
|
||||
}
|
||||
}
|
||||
assert tools
|
||||
return {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Calling tool",
|
||||
"tool_calls": [{
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"arguments": '{"param1": "test_value"}'
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
side_effect=completion_result,
|
||||
) as mock_chat:
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
None,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert mock_chat.call_count == 2
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert (
|
||||
result.response.speech["plain"]["speech"]
|
||||
== "I have successfully called the function"
|
||||
)
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "test_value"},
|
||||
),
|
||||
llm.LLMContext(
|
||||
platform="ollama",
|
||||
context=context,
|
||||
user_prompt="Please call the test function",
|
||||
language="en",
|
||||
assistant="conversation",
|
||||
device_id=None,
|
||||
),
|
||||
)
|
||||
|
||||
# Test Conversation tracing
|
||||
traces = trace.async_get_traces()
|
||||
assert traces
|
||||
last_trace = traces[-1].as_dict()
|
||||
trace_events = last_trace.get("events", [])
|
||||
assert [event["event_type"] for event in trace_events] == [
|
||||
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
trace.ConversationTraceEventType.LLM_TOOL_CALL,
|
||||
]
|
||||
|
||||
|
||||
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
|
||||
async def test_malformed_function_args(
|
||||
mock_get_tools,
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_assist: MockConfigEntry,
|
||||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test getting function args for an unknown function."""
|
||||
agent_id = mock_config_entry_with_assist.entry_id
|
||||
context = Context()
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.name = "test_tool"
|
||||
mock_tool.description = "Test function"
|
||||
mock_tool.parameters = vol.Schema(
|
||||
{vol.Optional("param1", description="Test parameters"): str}
|
||||
)
|
||||
mock_tool.async_call.return_value = "Test response"
|
||||
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
def completion_result(*args, messages, **kwargs):
|
||||
for message in messages:
|
||||
if message["content"].startswith("TOOL_ARGS"):
|
||||
return {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I was not able to call the function",
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "TOOL_ARGS unknown_tool",
|
||||
}
|
||||
}
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
side_effect=completion_result,
|
||||
) as mock_chat:
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
None,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert mock_tool.async_call.call_count == 0
|
||||
assert mock_chat.call_count == 2
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert (
|
||||
result.response.speech["plain"]["speech"]
|
||||
== "I was not able to call the function"
|
||||
)
|
||||
|
||||
|
||||
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
|
||||
async def test_malformed_function_call(
|
||||
mock_get_tools,
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_assist: MockConfigEntry,
|
||||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test function call that was unrecognized."""
|
||||
agent_id = mock_config_entry_with_assist.entry_id
|
||||
context = Context()
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.name = "test_tool"
|
||||
mock_tool.description = "Test function"
|
||||
mock_tool.parameters = vol.Schema(
|
||||
{vol.Optional("param1", description="Test parameters"): str}
|
||||
)
|
||||
mock_tool.async_call.return_value = "Test response"
|
||||
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
def completion_result(*args, messages, **kwargs):
|
||||
for message in messages:
|
||||
if message["content"].startswith("TOOL_CALL"):
|
||||
return {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I was not able to call the function",
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": 'TOOL_CALL name="test_tool", param1="test_value"',
|
||||
}
|
||||
}
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
side_effect=completion_result,
|
||||
) as mock_chat:
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
None,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert mock_tool.async_call.call_count == 0
|
||||
assert mock_chat.call_count == 2
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert (
|
||||
result.response.speech["plain"]["speech"]
|
||||
== "I was not able to call the function"
|
||||
)
|
||||
|
||||
|
||||
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
|
||||
async def test_function_exception(
|
||||
mock_get_tools,
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_assist: MockConfigEntry,
|
||||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test function call with exception."""
|
||||
agent_id = mock_config_entry_with_assist.entry_id
|
||||
context = Context()
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.name = "test_tool"
|
||||
mock_tool.description = "Test function"
|
||||
mock_tool.parameters = vol.Schema(
|
||||
{vol.Optional("param1", description="Test parameters"): str}
|
||||
)
|
||||
mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception")
|
||||
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
def completion_result(*args, messages, **kwargs):
|
||||
for message in messages:
|
||||
if message["content"].startswith("TOOL_CALL"):
|
||||
return {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "There was an error calling the function",
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": 'TOOL_CALL {"name": "test_tool", "parameters": {"param1": "test_value"}}',
|
||||
}
|
||||
}
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
side_effect=completion_result,
|
||||
) as mock_chat:
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
None,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
assert mock_chat.call_count == 2
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert (
|
||||
result.response.speech["plain"]["speech"]
|
||||
== "There was an error calling the function"
|
||||
)
|
||||
mock_tool.async_call.assert_awaited_once_with(
|
||||
hass,
|
||||
llm.ToolInput(
|
||||
tool_name="test_tool",
|
||||
tool_args={"param1": "test_value"},
|
||||
),
|
||||
llm.LLMContext(
|
||||
platform="ollama",
|
||||
context=context,
|
||||
user_prompt="Please call the test function",
|
||||
language="en",
|
||||
assistant="conversation",
|
||||
device_id=None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def test_unknown_hass_api(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
snapshot: SnapshotAssertion,
|
||||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test when we reference an API that no longer exists."""
|
||||
hass.config_entries.async_update_entry(
|
||||
mock_config_entry,
|
||||
options={
|
||||
**mock_config_entry.options,
|
||||
CONF_LLM_HASS_API: "non-existing",
|
||||
},
|
||||
)
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result == snapshot
|
||||
|
||||
|
||||
async def test_message_history_trimming(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
|
@ -128,9 +128,9 @@ asyncio.set_event_loop_policy(runner.HassEventLoopPolicy(False))
|
||||
asyncio.set_event_loop_policy = lambda policy: None
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
"""Register custom pytest options."""
|
||||
parser.addoption("--dburl", action="store", default="sqlite://")
|
||||
#def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
# """Register custom pytest options."""
|
||||
# parser.addoption("--dburl", action="store", default="sqlite://")
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user