diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py
index ccc7b9bdecc..797adfc9f41 100644
--- a/homeassistant/components/ollama/conversation.py
+++ b/homeassistant/components/ollama/conversation.py
@@ -2,25 +2,22 @@
 from __future__ import annotations
 
+from collections.abc import Callable
+import json
 import logging
 import time
-from typing import Literal
+from typing import Any, Literal
 
 import ollama
+import voluptuous as vol
+from voluptuous_openapi import convert
 
 from homeassistant.components import assist_pipeline, conversation
 from homeassistant.components.conversation import trace
-from homeassistant.components.homeassistant.exposed_entities import async_should_expose
 from homeassistant.config_entries import ConfigEntry
-from homeassistant.const import MATCH_ALL
+from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import TemplateError
-from homeassistant.helpers import (
-    area_registry as ar,
-    device_registry as dr,
-    entity_registry as er,
-    intent,
-    template,
-)
+from homeassistant.exceptions import HomeAssistantError, TemplateError
+from homeassistant.helpers import intent, llm, template
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
 from homeassistant.util import ulid
 
@@ -34,12 +31,28 @@ from .const import (
     DEFAULT_MAX_HISTORY,
-    DEFAULT_PROMPT,
     DOMAIN,
+    KEEP_ALIVE_FOREVER,
     MAX_HISTORY_SECONDS,
 )
-from .models import ExposedEntity, MessageHistory, MessageRole
+from .models import MessageHistory, MessageRole
+
+# Max number of back-and-forth turns with the LLM to generate a response
+MAX_TOOL_ITERATIONS = 10
 
 _LOGGER = logging.getLogger(__name__)
 
+
+def _format_tool(
+    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
+) -> dict[str, Any]:
+    """Format tool specification for the Ollama API."""
+    tool_spec = {
+        "name": tool.name,
+        "parameters": convert(tool.parameters, custom_serializer=custom_serializer),
+    }
+    if tool.description:
+        tool_spec["description"] = tool.description
+    return {"type": "function", "function": tool_spec}
+
 
 async def async_setup_entry(
     hass: HomeAssistant,
@@ -90,10 +103,52 @@ class OllamaConversationEntity(
     ) -> conversation.ConversationResult:
         """Process a sentence."""
         settings = {**self.entry.data, **self.entry.options}
 
         client = self.hass.data[DOMAIN][self.entry.entry_id]
         conversation_id = user_input.conversation_id or ulid.ulid_now()
         model = settings[CONF_MODEL]
+        intent_response = intent.IntentResponse(language=user_input.language)
+        llm_api: llm.APIInstance | None = None
+        tools: list[dict[str, Any]] | None = None
+        user_name: str | None = None
+        llm_context = llm.LLMContext(
+            platform=DOMAIN,
+            context=user_input.context,
+            user_prompt=user_input.text,
+            language=user_input.language,
+            assistant=conversation.DOMAIN,
+            device_id=user_input.device_id,
+        )
+
+        if settings.get(CONF_LLM_HASS_API):
+            try:
+                llm_api = await llm.async_get_api(
+                    self.hass,
+                    settings[CONF_LLM_HASS_API],
+                    llm_context,
+                )
+            except HomeAssistantError as err:
+                _LOGGER.error("Error getting LLM API: %s", err)
+                intent_response.async_set_error(
+                    intent.IntentResponseErrorCode.UNKNOWN,
+                    f"Error preparing LLM API: {err}",
+                )
+                return conversation.ConversationResult(
+                    response=intent_response, conversation_id=user_input.conversation_id
+                )
+            tools = [
+                _format_tool(tool, llm_api.custom_serializer)
+                for tool in llm_api.tools
+            ]
+
+        if (
+            user_input.context
+            and user_input.context.user_id
+            and (
+                user := await self.hass.auth.async_get_user(user_input.context.user_id)
+            )
+        ):
+            user_name = user.name
 
         # Look up message history
         message_history: MessageHistory | None = None
@@ -102,13 +157,23 @@ class OllamaConversationEntity(
             # New history
             #
             # Render prompt and error out early if there's a problem
-            raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
             try:
-                prompt = self._generate_prompt(raw_prompt)
-                _LOGGER.debug("Prompt: %s", prompt)
+                prompt_parts = [
+                    template.Template(
+                        llm.BASE_PROMPT
+                        + settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
+                        self.hass,
+                    ).async_render(
+                        {
+                            "ha_name": self.hass.config.location_name,
+                            "user_name": user_name,
+                            "llm_context": llm_context,
+                        },
+                        parse_result=False,
+                    )
+                ]
             except TemplateError as err:
                 _LOGGER.error("Error rendering prompt: %s", err)
-                intent_response = intent.IntentResponse(language=user_input.language)
                 intent_response.async_set_error(
                     intent.IntentResponseErrorCode.UNKNOWN,
                     f"Sorry, I had a problem generating my prompt: {err}",
@@ -117,6 +182,12 @@
                 response=intent_response, conversation_id=conversation_id
             )
 
+            if llm_api:
+                prompt_parts.append(llm_api.api_prompt)
+
+            prompt = "\n".join(prompt_parts)
+            _LOGGER.debug("Prompt: %s", prompt)
+
             message_history = MessageHistory(
                 timestamp=time.monotonic(),
                 messages=[
@@ -146,32 +217,64 @@ class OllamaConversationEntity(
             )
 
         # Get response
-        try:
-            response = await client.chat(
-                model=model,
-                # Make a copy of the messages because we mutate the list later
-                messages=list(message_history.messages),
-                stream=False,
-                # keep_alive requires specifying unit. In this case, seconds
-                keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
-            )
-        except (ollama.RequestError, ollama.ResponseError) as err:
-            _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
-            intent_response = intent.IntentResponse(language=user_input.language)
-            intent_response.async_set_error(
-                intent.IntentResponseErrorCode.UNKNOWN,
-                f"Sorry, I had a problem talking to the Ollama server: {err}",
-            )
-            return conversation.ConversationResult(
-                response=intent_response, conversation_id=conversation_id
-            )
+        # To prevent infinite loops, we limit the number of iterations
+        for _iteration in range(MAX_TOOL_ITERATIONS):
+            try:
+                response = await client.chat(
+                    model=model,
+                    # Make a copy of the messages because we mutate the list later
+                    messages=list(message_history.messages),
+                    stream=False,
+                    tools=tools,
+                    keep_alive=KEEP_ALIVE_FOREVER,
+                )
+            except (ollama.RequestError, ollama.ResponseError) as err:
+                _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
+                intent_response.async_set_error(
+                    intent.IntentResponseErrorCode.UNKNOWN,
+                    f"Sorry, I had a problem talking to the Ollama server: {err}",
+                )
+                return conversation.ConversationResult(
+                    response=intent_response, conversation_id=conversation_id
+                )
 
-        response_message = response["message"]
-        message_history.messages.append(
-            ollama.Message(
-                role=response_message["role"], content=response_message["content"]
-            )
-        )
+            _LOGGER.debug("Response: %s", response)
+            response_message = response["message"]
+            tool_calls = response_message.get("tool_calls")
+
+            def message_convert(response_message: Any) -> ollama.Message:
+                """Convert an Ollama response message to a history message."""
+                msg = ollama.Message(role=response_message["role"])
+                if content := response_message.get("content"):
+                    msg["content"] = content
+                if tool_calls := response_message.get("tool_calls"):
+                    msg["tool_calls"] = tool_calls
+                return msg
+
+            message_history.messages.append(message_convert(response_message))
+
+            if not tool_calls or not llm_api:
+                break
+
+            for tool_call in tool_calls:
+                try:
+                    tool_input = llm.ToolInput(
+                        tool_name=tool_call["function"]["name"],
+                        tool_args=json.loads(tool_call["function"]["arguments"]),
+                    )
+                    _LOGGER.debug(
+                        "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
+                    )
+                    tool_response = await llm_api.async_call_tool(tool_input)
+                except (HomeAssistantError, vol.Invalid, json.JSONDecodeError) as e:
+                    tool_response = {"error": type(e).__name__}
+                    if str(e):
+                        tool_response["error_text"] = str(e)
+
+                _LOGGER.debug("Tool response: %s", tool_response)
+                message_history.messages.append(
+                    ollama.Message(role="tool", content=json.dumps(tool_response))
+                )
 
         # Create intent response
         intent_response = intent.IntentResponse(language=user_input.language)
@@ -204,62 +307,3 @@ class OllamaConversationEntity(
             message_history.messages = [
                 message_history.messages[0]
             ] + message_history.messages[drop_index:]
-
-    def _generate_prompt(self, raw_prompt: str) -> str:
-        """Generate a prompt for the user."""
-        return template.Template(raw_prompt, self.hass).async_render(
-            {
-                "ha_name": self.hass.config.location_name,
-                "ha_language": self.hass.config.language,
-                "exposed_entities": self._get_exposed_entities(),
-            },
-            parse_result=False,
-        )
-
-    def _get_exposed_entities(self) -> list[ExposedEntity]:
-        """Get state list of exposed entities."""
-        area_registry = ar.async_get(self.hass)
-        entity_registry = er.async_get(self.hass)
-        device_registry = dr.async_get(self.hass)
-
-        exposed_entities = []
-        exposed_states = [
-            state
-            for state in self.hass.states.async_all()
-            if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
-        ]
-
-        for state in exposed_states:
-            entity_entry = entity_registry.async_get(state.entity_id)
-            names = [state.name]
-            area_names = []
-
-            if entity_entry is not None:
-                # Add aliases
-                names.extend(entity_entry.aliases)
-                if entity_entry.area_id and (
-                    area := area_registry.async_get_area(entity_entry.area_id)
-                ):
-                    # Entity is in area
-                    area_names.append(area.name)
-                    area_names.extend(area.aliases)
-                elif entity_entry.device_id and (
-                    device := device_registry.async_get(entity_entry.device_id)
-                ):
-                    # Check device area
-                    if device.area_id and (
-                        area := area_registry.async_get_area(device.area_id)
-                    ):
-                        area_names.append(area.name)
-                        area_names.extend(area.aliases)
-
-            exposed_entities.append(
-                ExposedEntity(
-                    entity_id=state.entity_id,
-                    state=state,
-                    names=names,
-                    area_names=area_names,
-                )
-            )
-
-        return exposed_entities
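
For reference, the loop added above is the standard two-phase tool-calling exchange against an Ollama server. The sketch below shows the same flow with the bare `ollama` client, outside Home Assistant. The host, model name, and `get_weather` tool are illustrative assumptions, not part of this patch, and it assumes a tool-capable model and a client version that accepts `tools=`; depending on the client version, `arguments` may arrive as a dict or as a JSON string.

import asyncio
import json

import ollama

# Hypothetical tool definition in the same shape _format_tool produces.
WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}


async def main() -> None:
    client = ollama.AsyncClient(host="http://localhost:11434")
    messages = [{"role": "user", "content": "What is the weather in Paris?"}]

    # Phase 1: the model either answers directly or requests a tool call.
    response = await client.chat(
        model="llama3.1", messages=messages, tools=[WEATHER_TOOL]
    )
    message = response["message"]
    messages.append(message)

    for tool_call in message.get("tool_calls") or []:
        raw_args = tool_call["function"]["arguments"]
        args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
        # The integration dispatches to llm_api.async_call_tool here; fake a result.
        result = {"city": args["city"], "temperature_c": 21}
        messages.append({"role": "tool", "content": json.dumps(result)})

    # Phase 2: the model turns the tool result into a final answer.
    response = await client.chat(model="llama3.1", messages=messages)
    print(response["message"]["content"])


asyncio.run(main())
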
diff --git a/tests/components/ollama/conftest.py b/tests/components/ollama/conftest.py
index db1689bd416..0355a13eba7 100644
--- a/tests/components/ollama/conftest.py
+++ b/tests/components/ollama/conftest.py
@@ -5,7 +5,9 @@ from unittest.mock import patch
 import pytest
 
 from homeassistant.components import ollama
+from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant
+from homeassistant.helpers import llm
 from homeassistant.setup import async_setup_component
 
 from . import TEST_OPTIONS, TEST_USER_DATA
@@ -25,6 +27,17 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
     return entry
 
 
+@pytest.fixture
+def mock_config_entry_with_assist(
+    hass: HomeAssistant, mock_config_entry: MockConfigEntry
+) -> MockConfigEntry:
+    """Mock a config entry with the Assist API enabled."""
+    hass.config_entries.async_update_entry(
+        mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
+    )
+    return mock_config_entry
+
+
 @pytest.fixture
 async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfigEntry):
     """Initialize integration."""
diff --git a/tests/components/ollama/snapshots/test_conversation.ambr b/tests/components/ollama/snapshots/test_conversation.ambr
new file mode 100644
index 00000000000..e4dd7cd00bb
--- /dev/null
+++ b/tests/components/ollama/snapshots/test_conversation.ambr
@@ -0,0 +1,34 @@
+# serializer version: 1
+# name: test_unknown_hass_api
+  dict({
+    'conversation_id': None,
+    'response': IntentResponse(
+      card=dict({
+      }),
+      error_code=<IntentResponseErrorCode.UNKNOWN: 'unknown'>,
+      failed_results=list([
+      ]),
+      intent=None,
+      intent_targets=list([
+      ]),
+      language='en',
+      matched_states=list([
+      ]),
+      reprompt=dict({
+      }),
+      response_type=<IntentResponseType.ERROR: 'error'>,
+      speech=dict({
+        'plain': dict({
+          'extra_data': None,
+          'speech': 'Error preparing LLM API: API non-existing not found',
+        }),
+      }),
+      speech_slots=dict({
+      }),
+      success_results=list([
+      ]),
+      unmatched_states=list([
+      ]),
+    ),
+  })
+# ---
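
In the tests that follow, `AssistAPI._async_get_tools` is patched to return a mock tool whose `parameters` field is a voluptuous schema. On the integration side, `_format_tool` feeds that schema through `voluptuous_openapi.convert` to get a JSON-schema dict the Ollama API understands. A rough illustration of that conversion; the exact output shape may differ between voluptuous_openapi versions:

import voluptuous as vol
from voluptuous_openapi import convert

schema = vol.Schema({vol.Optional("param1", description="Test parameters"): str})

# Prints roughly:
# {
#     "type": "object",
#     "properties": {
#         "param1": {"type": "string", "description": "Test parameters"}
#     },
#     "required": [],
# }
print(convert(schema))
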
diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py
index b6f0be3c414..1161f46cec5 100644
--- a/tests/components/ollama/test_conversation.py
+++ b/tests/components/ollama/test_conversation.py
@@ -1,24 +1,28 @@
 """Tests for the Ollama integration."""
 
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 from ollama import Message, ResponseError
 import pytest
+from syrupy.assertion import SnapshotAssertion
+import voluptuous as vol
 
 from homeassistant.components import conversation, ollama
 from homeassistant.components.conversation import trace
 from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
-from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
+from homeassistant.const import ATTR_FRIENDLY_NAME, CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import Context, HomeAssistant
+from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import (
     area_registry as ar,
     device_registry as dr,
     entity_registry as er,
     intent,
+    llm,
 )
 
 from tests.common import MockConfigEntry
 
 
 @pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
 async def test_chat(
@@ -124,6 +128,378 @@ async def test_chat(
     detail_event = trace_events[1]
     assert "The current time is" in detail_event["data"]["messages"][0]["content"]
 
 
+async def test_template_variables(
+    hass: HomeAssistant, mock_config_entry: MockConfigEntry
+) -> None:
+    """Test that template variables work."""
+    context = Context(user_id="12345")
+    mock_user = Mock()
+    mock_user.id = "12345"
+    mock_user.name = "Test User"
+
+    hass.config_entries.async_update_entry(
+        mock_config_entry,
+        options={
+            "prompt": (
+                "The user name is {{ user_name }}. "
+                "The user id is {{ llm_context.context.user_id }}."
+            ),
+        },
+    )
+    with (
+        patch("ollama.AsyncClient.list"),
+        patch(
+            "ollama.AsyncClient.chat",
+            return_value={"message": {"role": "assistant", "content": "test response"}},
+        ) as mock_chat,
+        patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user),
+    ):
+        await hass.config_entries.async_setup(mock_config_entry.entry_id)
+        await hass.async_block_till_done()
+        result = await conversation.async_converse(
+            hass, "hello", None, context, agent_id=mock_config_entry.entry_id
+        )
+
+    assert (
+        result.response.response_type == intent.IntentResponseType.ACTION_DONE
+    ), result
+
+    args = mock_chat.call_args.kwargs
+    prompt = args["messages"][0]["content"]
+
+    assert "The user name is Test User." in prompt
+    assert "The user id is 12345." in prompt
+
+
+@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
+async def test_function_call(
+    mock_get_tools,
+    hass: HomeAssistant,
+    mock_config_entry_with_assist: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test function call from the assistant."""
+    agent_id = mock_config_entry_with_assist.entry_id
+    context = Context()
+
+    mock_tool = AsyncMock()
+    mock_tool.name = "test_tool"
+    mock_tool.description = "Test function"
+    mock_tool.parameters = vol.Schema(
+        {vol.Optional("param1", description="Test parameters"): str}
+    )
+    mock_tool.async_call.return_value = "Test response"
+
+    mock_get_tools.return_value = [mock_tool]
+
+    def completion_result(*args, messages, tools, **kwargs):
+        assert tools
+        for message in messages:
+            if message["role"] == "tool":
+                return {
+                    "message": {
+                        "role": "assistant",
+                        "content": "I have successfully called the function",
+                    }
+                }
+        return {
+            "message": {
+                "role": "assistant",
+                "content": "Calling tool",
+                "tool_calls": [
+                    {
+                        "function": {
+                            "name": "test_tool",
+                            "arguments": '{"param1": "test_value"}',
+                        }
+                    }
+                ],
+            }
+        }
+
+    with patch(
+        "ollama.AsyncClient.chat",
+        side_effect=completion_result,
+    ) as mock_chat:
+        result = await conversation.async_converse(
+            hass,
+            "Please call the test function",
+            None,
+            context,
+            agent_id=agent_id,
+        )
+
+    assert mock_chat.call_count == 2
+    assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
+    assert (
+        result.response.speech["plain"]["speech"]
+        == "I have successfully called the function"
+    )
+    mock_tool.async_call.assert_awaited_once_with(
+        hass,
+        llm.ToolInput(
+            tool_name="test_tool",
+            tool_args={"param1": "test_value"},
+        ),
+        llm.LLMContext(
+            platform="ollama",
+            context=context,
+            user_prompt="Please call the test function",
+            language="en",
+            assistant="conversation",
+            device_id=None,
+        ),
+    )
+
+    # Test conversation tracing
+    traces = trace.async_get_traces()
+    assert traces
+    last_trace = traces[-1].as_dict()
+    trace_events = last_trace.get("events", [])
+    assert [event["event_type"] for event in trace_events] == [
+        trace.ConversationTraceEventType.ASYNC_PROCESS,
+        trace.ConversationTraceEventType.AGENT_DETAIL,
+        trace.ConversationTraceEventType.LLM_TOOL_CALL,
+    ]
+
+
+@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools")
+async def test_malformed_function_args(
+    mock_get_tools,
+    hass: HomeAssistant,
+    mock_config_entry_with_assist: MockConfigEntry,
+    mock_init_component,
+) -> None:
+    """Test a tool call whose arguments fail to parse as JSON."""
+    agent_id = mock_config_entry_with_assist.entry_id
+    context = Context()
+
+    mock_tool = AsyncMock()
+    mock_tool.name = "test_tool"
+    mock_tool.description = "Test function"
= "Test function" + mock_tool.parameters = vol.Schema( + {vol.Optional("param1", description="Test parameters"): str} + ) + mock_tool.async_call.return_value = "Test response" + + mock_get_tools.return_value = [mock_tool] + + def completion_result(*args, messages, **kwargs): + for message in messages: + if message["content"].startswith("TOOL_ARGS"): + return { + "message": { + "role": "assistant", + "content": "I was not able to call the function", + } + } + + return { + "message": { + "role": "assistant", + "content": "TOOL_ARGS unknown_tool", + } + } + + with patch( + "ollama.AsyncClient.chat", + side_effect=completion_result, + ) as mock_chat: + result = await conversation.async_converse( + hass, + "Please call the test function", + None, + context, + agent_id=agent_id, + ) + + assert mock_tool.async_call.call_count == 0 + assert mock_chat.call_count == 2 + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert ( + result.response.speech["plain"]["speech"] + == "I was not able to call the function" + ) + + +@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools") +async def test_malformed_function_call( + mock_get_tools, + hass: HomeAssistant, + mock_config_entry_with_assist: MockConfigEntry, + mock_init_component, +) -> None: + """Test function call that was unrecognized.""" + agent_id = mock_config_entry_with_assist.entry_id + context = Context() + + mock_tool = AsyncMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test function" + mock_tool.parameters = vol.Schema( + {vol.Optional("param1", description="Test parameters"): str} + ) + mock_tool.async_call.return_value = "Test response" + + mock_get_tools.return_value = [mock_tool] + + def completion_result(*args, messages, **kwargs): + for message in messages: + if message["content"].startswith("TOOL_CALL"): + return { + "message": { + "role": "assistant", + "content": "I was not able to call the function", + } + } + + return { + "message": { + "role": "assistant", + "content": 'TOOL_CALL name="test_tool", param1="test_value"', + } + } + + with patch( + "ollama.AsyncClient.chat", + side_effect=completion_result, + ) as mock_chat: + result = await conversation.async_converse( + hass, + "Please call the test function", + None, + context, + agent_id=agent_id, + ) + + assert mock_tool.async_call.call_count == 0 + assert mock_chat.call_count == 2 + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert ( + result.response.speech["plain"]["speech"] + == "I was not able to call the function" + ) + + +@patch("homeassistant.components.ollama.conversation.llm.AssistAPI._async_get_tools") +async def test_function_exception( + mock_get_tools, + hass: HomeAssistant, + mock_config_entry_with_assist: MockConfigEntry, + mock_init_component, +) -> None: + """Test function call with exception.""" + agent_id = mock_config_entry_with_assist.entry_id + context = Context() + + mock_tool = AsyncMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test function" + mock_tool.parameters = vol.Schema( + {vol.Optional("param1", description="Test parameters"): str} + ) + mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception") + + mock_get_tools.return_value = [mock_tool] + + def completion_result(*args, messages, **kwargs): + for message in messages: + if message["content"].startswith("TOOL_CALL"): + return { + "message": { + "role": "assistant", + "content": "There was an error calling the function", + } + } + + return { + 
"message": { + "role": "assistant", + "content": 'TOOL_CALL {"name": "test_tool", "parameters": {"param1": "test_value"}}', + } + } + + with patch( + "ollama.AsyncClient.chat", + side_effect=completion_result, + ) as mock_chat: + result = await conversation.async_converse( + hass, + "Please call the test function", + None, + context, + agent_id=agent_id, + ) + + assert mock_chat.call_count == 2 + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert ( + result.response.speech["plain"]["speech"] + == "There was an error calling the function" + ) + mock_tool.async_call.assert_awaited_once_with( + hass, + llm.ToolInput( + tool_name="test_tool", + tool_args={"param1": "test_value"}, + ), + llm.LLMContext( + platform="ollama", + context=context, + user_prompt="Please call the test function", + language="en", + assistant="conversation", + device_id=None, + ), + ) + + +async def test_unknown_hass_api( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + snapshot: SnapshotAssertion, + mock_init_component, +) -> None: + """Test when we reference an API that no longer exists.""" + hass.config_entries.async_update_entry( + mock_config_entry, + options={ + **mock_config_entry.options, + CONF_LLM_HASS_API: "non-existing", + }, + ) + + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result == snapshot + async def test_message_history_trimming( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component diff --git a/tests/conftest.py b/tests/conftest.py index 935ceffa108..2826d28a530 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -128,9 +128,9 @@ asyncio.set_event_loop_policy(runner.HassEventLoopPolicy(False)) asyncio.set_event_loop_policy = lambda policy: None -def pytest_addoption(parser: pytest.Parser) -> None: - """Register custom pytest options.""" - parser.addoption("--dburl", action="store", default="sqlite://") +#def pytest_addoption(parser: pytest.Parser) -> None: +# """Register custom pytest options.""" +# parser.addoption("--dburl", action="store", default="sqlite://") def pytest_configure(config: pytest.Config) -> None: