mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 22:27:07 +00:00
Ollama tool calling
This commit is contained in:
parent
975cfa6457
commit
8f688ee079
@ -2,25 +2,31 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import ollama
|
import ollama
|
||||||
|
import voluptuous as vol
|
||||||
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, conversation
|
from homeassistant.components import assist_pipeline, conversation
|
||||||
from homeassistant.components.conversation import trace
|
from homeassistant.components.conversation import trace
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.const import MATCH_ALL
|
from homeassistant.const import MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import TemplateError
|
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
area_registry as ar,
|
area_registry as ar,
|
||||||
device_registry as dr,
|
device_registry as dr,
|
||||||
entity_registry as er,
|
entity_registry as er,
|
||||||
intent,
|
intent,
|
||||||
template,
|
template,
|
||||||
|
llm,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.util import ulid
|
from homeassistant.util import ulid
|
||||||
@ -34,12 +40,28 @@ from .const import (
|
|||||||
DEFAULT_MAX_HISTORY,
|
DEFAULT_MAX_HISTORY,
|
||||||
DEFAULT_PROMPT,
|
DEFAULT_PROMPT,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
KEEP_ALIVE_FOREVER,
|
||||||
MAX_HISTORY_SECONDS,
|
MAX_HISTORY_SECONDS,
|
||||||
)
|
)
|
||||||
from .models import ExposedEntity, MessageHistory, MessageRole
|
from .models import MessageHistory, MessageRole
|
||||||
|
|
||||||
|
# Max number of back and forth with the LLM to generate a response
|
||||||
|
MAX_TOOL_ITERATIONS = 10
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def _format_tool(
|
||||||
|
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Format tool specification."""
|
||||||
|
tool_spec = {
|
||||||
|
"name": tool.name,
|
||||||
|
"parameters": convert(tool.parameters, custom_serializer=custom_serializer),
|
||||||
|
}
|
||||||
|
if tool.description:
|
||||||
|
tool_spec["description"] = tool.description
|
||||||
|
return tool_spec
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(
|
async def async_setup_entry(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
@ -90,10 +112,55 @@ class OllamaConversationEntity(
|
|||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
settings = {**self.entry.data, **self.entry.options}
|
settings = {**self.entry.data, **self.entry.options}
|
||||||
|
options = self.entry.options
|
||||||
|
|
||||||
client = self.hass.data[DOMAIN][self.entry.entry_id]
|
client = self.hass.data[DOMAIN][self.entry.entry_id]
|
||||||
conversation_id = user_input.conversation_id or ulid.ulid_now()
|
conversation_id = user_input.conversation_id or ulid.ulid_now()
|
||||||
model = settings[CONF_MODEL]
|
model = settings[CONF_MODEL]
|
||||||
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
|
llm_api: llm.APIInstance | None = None
|
||||||
|
tools: dict[str, dict[str, Any]] | None = None
|
||||||
|
user_name: str | None = None
|
||||||
|
llm_context = llm.LLMContext(
|
||||||
|
platform=DOMAIN,
|
||||||
|
context=user_input.context,
|
||||||
|
user_prompt=user_input.text,
|
||||||
|
language=user_input.language,
|
||||||
|
assistant=conversation.DOMAIN,
|
||||||
|
device_id=user_input.device_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
_LOGGER.debug("CONF_LLM_HASS_API=%s", settings.get(CONF_LLM_HASS_API))
|
||||||
|
if settings.get(CONF_LLM_HASS_API):
|
||||||
|
try:
|
||||||
|
llm_api = await llm.async_get_api(
|
||||||
|
self.hass,
|
||||||
|
settings[CONF_LLM_HASS_API],
|
||||||
|
llm_context,
|
||||||
|
)
|
||||||
|
except HomeAssistantError as err:
|
||||||
|
_LOGGER.error("Error getting LLM API: %s", err)
|
||||||
|
intent_response.async_set_error(
|
||||||
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
|
f"Error preparing LLM API: {err}",
|
||||||
|
)
|
||||||
|
return conversation.ConversationResult(
|
||||||
|
response=intent_response, conversation_id=user_input.conversation_id
|
||||||
|
)
|
||||||
|
tools = {
|
||||||
|
tool.name: _format_tool(tool, llm_api.custom_serializer)
|
||||||
|
for tool in llm_api.tools
|
||||||
|
}
|
||||||
|
_LOGGER.debug("tools=%s", tools)
|
||||||
|
|
||||||
|
if (
|
||||||
|
user_input.context
|
||||||
|
and user_input.context.user_id
|
||||||
|
and (
|
||||||
|
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
user_name = user.name
|
||||||
|
|
||||||
# Look up message history
|
# Look up message history
|
||||||
message_history: MessageHistory | None = None
|
message_history: MessageHistory | None = None
|
||||||
@ -102,13 +169,23 @@ class OllamaConversationEntity(
|
|||||||
# New history
|
# New history
|
||||||
#
|
#
|
||||||
# Render prompt and error out early if there's a problem
|
# Render prompt and error out early if there's a problem
|
||||||
raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
|
|
||||||
try:
|
try:
|
||||||
prompt = self._generate_prompt(raw_prompt)
|
prompt_parts = [
|
||||||
_LOGGER.debug("Prompt: %s", prompt)
|
template.Template(
|
||||||
|
llm.BASE_PROMPT
|
||||||
|
+ settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
||||||
|
self.hass,
|
||||||
|
).async_render(
|
||||||
|
{
|
||||||
|
"ha_name": self.hass.config.location_name,
|
||||||
|
"user_name": user_name,
|
||||||
|
"llm_context": llm_context,
|
||||||
|
},
|
||||||
|
parse_result=False,
|
||||||
|
)
|
||||||
|
]
|
||||||
except TemplateError as err:
|
except TemplateError as err:
|
||||||
_LOGGER.error("Error rendering prompt: %s", err)
|
_LOGGER.error("Error rendering prompt: %s", err)
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
|
||||||
intent_response.async_set_error(
|
intent_response.async_set_error(
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
f"Sorry, I had a problem generating my prompt: {err}",
|
f"Sorry, I had a problem generating my prompt: {err}",
|
||||||
@ -117,6 +194,16 @@ class OllamaConversationEntity(
|
|||||||
response=intent_response, conversation_id=conversation_id
|
response=intent_response, conversation_id=conversation_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if llm_api:
|
||||||
|
_LOGGER.debug("llm api prompt parts")
|
||||||
|
prompt_parts.append(llm_api.api_prompt)
|
||||||
|
else:
|
||||||
|
_LOGGER.debug("no llm api prompt parts")
|
||||||
|
|
||||||
|
|
||||||
|
prompt = "\n".join(prompt_parts)
|
||||||
|
_LOGGER.debug("Prompt: %s", prompt)
|
||||||
|
|
||||||
message_history = MessageHistory(
|
message_history = MessageHistory(
|
||||||
timestamp=time.monotonic(),
|
timestamp=time.monotonic(),
|
||||||
messages=[
|
messages=[
|
||||||
@ -146,32 +233,71 @@ class OllamaConversationEntity(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get response
|
# Get response
|
||||||
try:
|
# To prevent infinite loops, we limit the number of iterations
|
||||||
response = await client.chat(
|
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||||
model=model,
|
try:
|
||||||
# Make a copy of the messages because we mutate the list later
|
response = await client.chat(
|
||||||
messages=list(message_history.messages),
|
model=model,
|
||||||
stream=False,
|
# Make a copy of the messages because we mutate the list later
|
||||||
# keep_alive requires specifying unit. In this case, seconds
|
messages=list(message_history.messages),
|
||||||
keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
|
stream=False,
|
||||||
)
|
tools=tools,
|
||||||
except (ollama.RequestError, ollama.ResponseError) as err:
|
keep_alive=KEEP_ALIVE_FOREVER,
|
||||||
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
|
)
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
except (ollama.RequestError, ollama.ResponseError) as err:
|
||||||
intent_response.async_set_error(
|
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
intent_response.async_set_error(
|
||||||
f"Sorry, I had a problem talking to the Ollama server: {err}",
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
)
|
f"Sorry, I had a problem talking to the Ollama server: {err}",
|
||||||
return conversation.ConversationResult(
|
)
|
||||||
response=intent_response, conversation_id=conversation_id
|
return conversation.ConversationResult(
|
||||||
)
|
response=intent_response, conversation_id=conversation_id
|
||||||
|
)
|
||||||
|
|
||||||
response_message = response["message"]
|
_LOGGER.debug("Response: %s", response)
|
||||||
message_history.messages.append(
|
response_message = response["message"]
|
||||||
ollama.Message(
|
tool_calls = response_message.get("tool_calls")
|
||||||
role=response_message["role"], content=response_message["content"]
|
|
||||||
)
|
def message_convert(response_message: Any) -> ollama.Message:
|
||||||
)
|
msg = ollama.Message(
|
||||||
|
role=response_message["role"]
|
||||||
|
)
|
||||||
|
if content := response_message.get("content"):
|
||||||
|
msg["content"] = content
|
||||||
|
if tool_calls := response_message.get("tool_calls"):
|
||||||
|
msg["tool_calls"] = tool_calls
|
||||||
|
return msg
|
||||||
|
|
||||||
|
message_history.messages.append(message_convert(response_message))
|
||||||
|
|
||||||
|
if not tool_calls or not llm_api:
|
||||||
|
_LOGGER.debug("tool_calls=%s", tool_calls)
|
||||||
|
_LOGGER.debug("llm_api=%s", llm_api)
|
||||||
|
break
|
||||||
|
|
||||||
|
_LOGGER.debug("Response: %s", response_message.get("content"))
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
tool_input = llm.ToolInput(
|
||||||
|
tool_name=tool_call["function"]["name"],
|
||||||
|
tool_args=json.loads(tool_call["function"]["arguments"]),
|
||||||
|
)
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_response = await llm_api.async_call_tool(tool_input)
|
||||||
|
except (HomeAssistantError, vol.Invalid) as e:
|
||||||
|
tool_response = {"error": type(e).__name__}
|
||||||
|
if str(e):
|
||||||
|
tool_response["error_text"] = str(e)
|
||||||
|
|
||||||
|
_LOGGER.debug("Tool response: %s", tool_response)
|
||||||
|
message_history.messages.append(
|
||||||
|
ollama.Message(
|
||||||
|
role="tool", content=json.dumps(tool_response)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Create intent response
|
# Create intent response
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
@ -204,62 +330,3 @@ class OllamaConversationEntity(
|
|||||||
message_history.messages = [
|
message_history.messages = [
|
||||||
message_history.messages[0]
|
message_history.messages[0]
|
||||||
] + message_history.messages[drop_index:]
|
] + message_history.messages[drop_index:]
|
||||||
|
|
||||||
def _generate_prompt(self, raw_prompt: str) -> str:
|
|
||||||
"""Generate a prompt for the user."""
|
|
||||||
return template.Template(raw_prompt, self.hass).async_render(
|
|
||||||
{
|
|
||||||
"ha_name": self.hass.config.location_name,
|
|
||||||
"ha_language": self.hass.config.language,
|
|
||||||
"exposed_entities": self._get_exposed_entities(),
|
|
||||||
},
|
|
||||||
parse_result=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_exposed_entities(self) -> list[ExposedEntity]:
|
|
||||||
"""Get state list of exposed entities."""
|
|
||||||
area_registry = ar.async_get(self.hass)
|
|
||||||
entity_registry = er.async_get(self.hass)
|
|
||||||
device_registry = dr.async_get(self.hass)
|
|
||||||
|
|
||||||
exposed_entities = []
|
|
||||||
exposed_states = [
|
|
||||||
state
|
|
||||||
for state in self.hass.states.async_all()
|
|
||||||
if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
|
|
||||||
]
|
|
||||||
|
|
||||||
for state in exposed_states:
|
|
||||||
entity_entry = entity_registry.async_get(state.entity_id)
|
|
||||||
names = [state.name]
|
|
||||||
area_names = []
|
|
||||||
|
|
||||||
if entity_entry is not None:
|
|
||||||
# Add aliases
|
|
||||||
names.extend(entity_entry.aliases)
|
|
||||||
if entity_entry.area_id and (
|
|
||||||
area := area_registry.async_get_area(entity_entry.area_id)
|
|
||||||
):
|
|
||||||
# Entity is in area
|
|
||||||
area_names.append(area.name)
|
|
||||||
area_names.extend(area.aliases)
|
|
||||||
elif entity_entry.device_id and (
|
|
||||||
device := device_registry.async_get(entity_entry.device_id)
|
|
||||||
):
|
|
||||||
# Check device area
|
|
||||||
if device.area_id and (
|
|
||||||
area := area_registry.async_get_area(device.area_id)
|
|
||||||
):
|
|
||||||
area_names.append(area.name)
|
|
||||||
area_names.extend(area.aliases)
|
|
||||||
|
|
||||||
exposed_entities.append(
|
|
||||||
ExposedEntity(
|
|
||||||
entity_id=state.entity_id,
|
|
||||||
state=state,
|
|
||||||
names=names,
|
|
||||||
area_names=area_names,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return exposed_entities
|
|
||||||
|
@ -5,7 +5,9 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import ollama
|
from homeassistant.components import ollama
|
||||||
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import llm
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import TEST_OPTIONS, TEST_USER_DATA
|
from . import TEST_OPTIONS, TEST_USER_DATA
|
||||||
@ -25,6 +27,17 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
|||||||
return entry
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_config_entry_with_assist(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||||
|
) -> MockConfigEntry:
|
||||||
|
"""Mock a config entry with assist."""
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
|
||||||
|
)
|
||||||
|
return mock_config_entry
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfigEntry):
|
async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfigEntry):
|
||||||
"""Initialize integration."""
|
"""Initialize integration."""
|
||||||
|
34
tests/components/ollama/snapshots/test_conversation.ambr
Normal file
34
tests/components/ollama/snapshots/test_conversation.ambr
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_unknown_hass_api
|
||||||
|
dict({
|
||||||
|
'conversation_id': None,
|
||||||
|
'response': IntentResponse(
|
||||||
|
card=dict({
|
||||||
|
}),
|
||||||
|
error_code=<IntentResponseErrorCode.UNKNOWN: 'unknown'>,
|
||||||
|
failed_results=list([
|
||||||
|
]),
|
||||||
|
intent=None,
|
||||||
|
intent_targets=list([
|
||||||
|
]),
|
||||||
|
language='en',
|
||||||
|
matched_states=list([
|
||||||
|
]),
|
||||||
|
reprompt=dict({
|
||||||
|
}),
|
||||||
|
response_type=<IntentResponseType.ERROR: 'error'>,
|
||||||
|
speech=dict({
|
||||||
|
'plain': dict({
|
||||||
|
'extra_data': None,
|
||||||
|
'speech': 'Error preparing LLM API: API non-existing not found',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
speech_slots=dict({
|
||||||
|
}),
|
||||||
|
success_results=list([
|
||||||
|
]),
|
||||||
|
unmatched_states=list([
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
})
|
||||||
|
# ---
|
@ -1,24 +1,31 @@
|
|||||||
"""Tests for the Ollama integration."""
|
"""Tests for the Ollama integration."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
import logging
|
||||||
|
|
||||||
from ollama import Message, ResponseError
|
from ollama import Message, ResponseError
|
||||||
import pytest
|
import pytest
|
||||||
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import conversation, ollama
|
from homeassistant.components import conversation, ollama
|
||||||
from homeassistant.components.conversation import trace
|
from homeassistant.components.conversation import trace
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||||
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
|
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
area_registry as ar,
|
area_registry as ar,
|
||||||
device_registry as dr,
|
device_registry as dr,
|
||||||
entity_registry as er,
|
entity_registry as er,
|
||||||
intent,
|
intent,
|
||||||
|
llm,
|
||||||
)
|
)
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
|
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
|
||||||
async def test_chat(
|
async def test_chat(
|
||||||
@ -124,6 +131,352 @@ async def test_chat(
|
|||||||
detail_event = trace_events[1]
|
detail_event = trace_events[1]
|
||||||
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
|
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
|
||||||
|
|
||||||
|
async def test_template_variables(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||||
|
) -> None:
|
||||||
|
"""Test that template variables work."""
|
||||||
|
context = Context(user_id="12345")
|
||||||
|
mock_user = Mock()
|
||||||
|
mock_user.id = "12345"
|
||||||
|
mock_user.name = "Test User"
|
||||||
|
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
mock_config_entry,
|
||||||
|
options={
|
||||||
|
"prompt": (
|
||||||
|
"The user name is {{ user_name }}. "
|
||||||
|
"The user id is {{ llm_context.context.user_id }}."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with (
|
||||||
|
patch("ollama.AsyncClient.list"),
|
||||||
|
patch(
|
||||||
|
"ollama.AsyncClient.chat",
|
||||||
|
return_value={"message": {"role": "assistant", "content": "test response"}},
|
||||||
|
) as mock_chat,
|
||||||
|
patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
|
||||||
|
):
|
||||||
|
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass, "hello", None, context, agent_id=mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
), result
|
||||||
|
|
||||||
|
args = mock_chat.call_args.kwargs
|
||||||
|
prompt = args["messages"][0]["content"]
|
||||||
|
|
||||||
|
assert "The user name is Test User." in prompt
|
||||||
|
assert "The user id is 12345." in prompt
|
||||||
|
|
||||||
|
|
||||||
|
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
|
||||||
|
async def test_function_call(
|
||||||
|
mock_get_tools,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
) -> None:
|
||||||
|
"""Test function call from the assistant."""
|
||||||
|
agent_id = mock_config_entry_with_assist.entry_id
|
||||||
|
context = Context()
|
||||||
|
|
||||||
|
mock_tool = AsyncMock()
|
||||||
|
mock_tool.name = "test_tool"
|
||||||
|
mock_tool.description = "Test function"
|
||||||
|
mock_tool.parameters = vol.Schema(
|
||||||
|
{vol.Optional("param1", description="Test parameters"): str}
|
||||||
|
)
|
||||||
|
mock_tool.async_call.return_value = "Test response"
|
||||||
|
|
||||||
|
mock_get_tools.return_value = [mock_tool]
|
||||||
|
|
||||||
|
def completion_result(*args, messages, tools, **kwargs):
|
||||||
|
_LOGGER.debug("tools=%s", tools)
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] == "tool":
|
||||||
|
return {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I have successfully called the function",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert tools
|
||||||
|
return {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Calling tool",
|
||||||
|
"tool_calls": [{
|
||||||
|
"function": {
|
||||||
|
"name": "test_tool",
|
||||||
|
"arguments": '{"param1": "test_value"}'
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"ollama.AsyncClient.chat",
|
||||||
|
side_effect=completion_result,
|
||||||
|
) as mock_chat:
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"Please call the test function",
|
||||||
|
None,
|
||||||
|
context,
|
||||||
|
agent_id=agent_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_chat.call_count == 2
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
assert (
|
||||||
|
result.response.speech["plain"]["speech"]
|
||||||
|
== "I have successfully called the function"
|
||||||
|
)
|
||||||
|
mock_tool.async_call.assert_awaited_once_with(
|
||||||
|
hass,
|
||||||
|
llm.ToolInput(
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={"param1": "test_value"},
|
||||||
|
),
|
||||||
|
llm.LLMContext(
|
||||||
|
platform="ollama",
|
||||||
|
context=context,
|
||||||
|
user_prompt="Please call the test function",
|
||||||
|
language="en",
|
||||||
|
assistant="conversation",
|
||||||
|
device_id=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test Conversation tracing
|
||||||
|
traces = trace.async_get_traces()
|
||||||
|
assert traces
|
||||||
|
last_trace = traces[-1].as_dict()
|
||||||
|
trace_events = last_trace.get("events", [])
|
||||||
|
assert [event["event_type"] for event in trace_events] == [
|
||||||
|
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
||||||
|
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||||
|
trace.ConversationTraceEventType.LLM_TOOL_CALL,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
|
||||||
|
async def test_malformed_function_args(
|
||||||
|
mock_get_tools,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting function args for an unknown function."""
|
||||||
|
agent_id = mock_config_entry_with_assist.entry_id
|
||||||
|
context = Context()
|
||||||
|
|
||||||
|
mock_tool = AsyncMock()
|
||||||
|
mock_tool.name = "test_tool"
|
||||||
|
mock_tool.description = "Test function"
|
||||||
|
mock_tool.parameters = vol.Schema(
|
||||||
|
{vol.Optional("param1", description="Test parameters"): str}
|
||||||
|
)
|
||||||
|
mock_tool.async_call.return_value = "Test response"
|
||||||
|
|
||||||
|
mock_get_tools.return_value = [mock_tool]
|
||||||
|
|
||||||
|
def completion_result(*args, messages, **kwargs):
|
||||||
|
for message in messages:
|
||||||
|
if message["content"].startswith("TOOL_ARGS"):
|
||||||
|
return {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I was not able to call the function",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "TOOL_ARGS unknown_tool",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"ollama.AsyncClient.chat",
|
||||||
|
side_effect=completion_result,
|
||||||
|
) as mock_chat:
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"Please call the test function",
|
||||||
|
None,
|
||||||
|
context,
|
||||||
|
agent_id=agent_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_tool.async_call.call_count == 0
|
||||||
|
assert mock_chat.call_count == 2
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
assert (
|
||||||
|
result.response.speech["plain"]["speech"]
|
||||||
|
== "I was not able to call the function"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
|
||||||
|
async def test_malformed_function_call(
|
||||||
|
mock_get_tools,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
) -> None:
|
||||||
|
"""Test function call that was unrecognized."""
|
||||||
|
agent_id = mock_config_entry_with_assist.entry_id
|
||||||
|
context = Context()
|
||||||
|
|
||||||
|
mock_tool = AsyncMock()
|
||||||
|
mock_tool.name = "test_tool"
|
||||||
|
mock_tool.description = "Test function"
|
||||||
|
mock_tool.parameters = vol.Schema(
|
||||||
|
{vol.Optional("param1", description="Test parameters"): str}
|
||||||
|
)
|
||||||
|
mock_tool.async_call.return_value = "Test response"
|
||||||
|
|
||||||
|
mock_get_tools.return_value = [mock_tool]
|
||||||
|
|
||||||
|
def completion_result(*args, messages, **kwargs):
|
||||||
|
for message in messages:
|
||||||
|
if message["content"].startswith("TOOL_CALL"):
|
||||||
|
return {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I was not able to call the function",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": 'TOOL_CALL name="test_tool", param1="test_value"',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"ollama.AsyncClient.chat",
|
||||||
|
side_effect=completion_result,
|
||||||
|
) as mock_chat:
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"Please call the test function",
|
||||||
|
None,
|
||||||
|
context,
|
||||||
|
agent_id=agent_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_tool.async_call.call_count == 0
|
||||||
|
assert mock_chat.call_count == 2
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
assert (
|
||||||
|
result.response.speech["plain"]["speech"]
|
||||||
|
== "I was not able to call the function"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
|
||||||
|
async def test_function_exception(
|
||||||
|
mock_get_tools,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry_with_assist: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
) -> None:
|
||||||
|
"""Test function call with exception."""
|
||||||
|
agent_id = mock_config_entry_with_assist.entry_id
|
||||||
|
context = Context()
|
||||||
|
|
||||||
|
mock_tool = AsyncMock()
|
||||||
|
mock_tool.name = "test_tool"
|
||||||
|
mock_tool.description = "Test function"
|
||||||
|
mock_tool.parameters = vol.Schema(
|
||||||
|
{vol.Optional("param1", description="Test parameters"): str}
|
||||||
|
)
|
||||||
|
mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception")
|
||||||
|
|
||||||
|
mock_get_tools.return_value = [mock_tool]
|
||||||
|
|
||||||
|
def completion_result(*args, messages, **kwargs):
|
||||||
|
for message in messages:
|
||||||
|
if message["content"].startswith("TOOL_CALL"):
|
||||||
|
return {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "There was an error calling the function",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": 'TOOL_CALL {"name": "test_tool", "parameters": {"param1": "test_value"}}',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"ollama.AsyncClient.chat",
|
||||||
|
side_effect=completion_result,
|
||||||
|
) as mock_chat:
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"Please call the test function",
|
||||||
|
None,
|
||||||
|
context,
|
||||||
|
agent_id=agent_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_chat.call_count == 2
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
assert (
|
||||||
|
result.response.speech["plain"]["speech"]
|
||||||
|
== "There was an error calling the function"
|
||||||
|
)
|
||||||
|
mock_tool.async_call.assert_awaited_once_with(
|
||||||
|
hass,
|
||||||
|
llm.ToolInput(
|
||||||
|
tool_name="test_tool",
|
||||||
|
tool_args={"param1": "test_value"},
|
||||||
|
),
|
||||||
|
llm.LLMContext(
|
||||||
|
platform="ollama",
|
||||||
|
context=context,
|
||||||
|
user_prompt="Please call the test function",
|
||||||
|
language="en",
|
||||||
|
assistant="conversation",
|
||||||
|
device_id=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_unknown_hass_api(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
mock_init_component,
|
||||||
|
) -> None:
|
||||||
|
"""Test when we reference an API that no longer exists."""
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
mock_config_entry,
|
||||||
|
options={
|
||||||
|
**mock_config_entry.options,
|
||||||
|
CONF_LLM_HASS_API: "non-existing",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == snapshot
|
||||||
|
|
||||||
|
|
||||||
async def test_message_history_trimming(
|
async def test_message_history_trimming(
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||||
|
@ -128,9 +128,9 @@ asyncio.set_event_loop_policy(runner.HassEventLoopPolicy(False))
|
|||||||
asyncio.set_event_loop_policy = lambda policy: None
|
asyncio.set_event_loop_policy = lambda policy: None
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
#def pytest_addoption(parser: pytest.Parser) -> None:
|
||||||
"""Register custom pytest options."""
|
# """Register custom pytest options."""
|
||||||
parser.addoption("--dburl", action="store", default="sqlite://")
|
# parser.addoption("--dburl", action="store", default="sqlite://")
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config: pytest.Config) -> None:
|
def pytest_configure(config: pytest.Config) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user