From 3304e27fa3d21a8a11f6032839becf65dc7a5fb9 Mon Sep 17 00:00:00 2001
From: Allen Porter
Date: Sun, 25 Aug 2024 12:22:57 -0700
Subject: [PATCH] Add ollama context window size configuration (#124555)

* Add ollama context window size configuration

* Set higher max context size
---
 homeassistant/components/ollama/__init__.py  |  2 ++
 .../components/ollama/config_flow.py         | 12 +++++++
 homeassistant/components/ollama/const.py     |  5 +++
 .../components/ollama/conversation.py        |  3 ++
 homeassistant/components/ollama/strings.json |  4 ++-
 tests/components/ollama/conftest.py          | 13 ++++++--
 tests/components/ollama/test_config_flow.py  |  7 ++++-
 tests/components/ollama/test_conversation.py | 31 +++++++++++++++++++
 8 files changed, 73 insertions(+), 4 deletions(-)

diff --git a/homeassistant/components/ollama/__init__.py b/homeassistant/components/ollama/__init__.py
index 2286a2c7b75..2ad389c55c3 100644
--- a/homeassistant/components/ollama/__init__.py
+++ b/homeassistant/components/ollama/__init__.py
@@ -18,6 +18,7 @@ from .const import (
     CONF_KEEP_ALIVE,
     CONF_MAX_HISTORY,
     CONF_MODEL,
+    CONF_NUM_CTX,
     CONF_PROMPT,
     DEFAULT_TIMEOUT,
     DOMAIN,
@@ -30,6 +31,7 @@ __all__ = [
     "CONF_PROMPT",
     "CONF_MODEL",
     "CONF_MAX_HISTORY",
+    "CONF_NUM_CTX",
     "CONF_KEEP_ALIVE",
     "DOMAIN",
 ]
diff --git a/homeassistant/components/ollama/config_flow.py b/homeassistant/components/ollama/config_flow.py
index bcdd6e06f48..6b516d67138 100644
--- a/homeassistant/components/ollama/config_flow.py
+++ b/homeassistant/components/ollama/config_flow.py
@@ -38,12 +38,16 @@ from .const import (
     CONF_KEEP_ALIVE,
     CONF_MAX_HISTORY,
     CONF_MODEL,
+    CONF_NUM_CTX,
     CONF_PROMPT,
     DEFAULT_KEEP_ALIVE,
     DEFAULT_MAX_HISTORY,
     DEFAULT_MODEL,
+    DEFAULT_NUM_CTX,
     DEFAULT_TIMEOUT,
     DOMAIN,
+    MAX_NUM_CTX,
+    MIN_NUM_CTX,
     MODEL_NAMES,
 )
@@ -255,6 +259,14 @@ def ollama_config_option_schema(
             description={"suggested_value": options.get(CONF_LLM_HASS_API)},
             default="none",
         ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
+        vol.Optional(
+            CONF_NUM_CTX,
+            description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
+        ): NumberSelector(
+            NumberSelectorConfig(
+                min=MIN_NUM_CTX, max=MAX_NUM_CTX, step=1, mode=NumberSelectorMode.BOX
+            )
+        ),
         vol.Optional(
             CONF_MAX_HISTORY,
             description={
diff --git a/homeassistant/components/ollama/const.py b/homeassistant/components/ollama/const.py
index 97c4f1186fc..6152b223d6d 100644
--- a/homeassistant/components/ollama/const.py
+++ b/homeassistant/components/ollama/const.py
@@ -11,6 +11,11 @@ DEFAULT_KEEP_ALIVE = -1  # seconds. -1 = indefinite, 0 = never
 KEEP_ALIVE_FOREVER = -1
 DEFAULT_TIMEOUT = 5.0  # seconds
 
+CONF_NUM_CTX = "num_ctx"
+DEFAULT_NUM_CTX = 8192
+MIN_NUM_CTX = 2048
+MAX_NUM_CTX = 131072
+
 CONF_MAX_HISTORY = "max_history"
 DEFAULT_MAX_HISTORY = 20
diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py
index c0423b258f0..1a91c790d27 100644
--- a/homeassistant/components/ollama/conversation.py
+++ b/homeassistant/components/ollama/conversation.py
@@ -26,9 +26,11 @@ from .const import (
     CONF_KEEP_ALIVE,
     CONF_MAX_HISTORY,
     CONF_MODEL,
+    CONF_NUM_CTX,
     CONF_PROMPT,
     DEFAULT_KEEP_ALIVE,
     DEFAULT_MAX_HISTORY,
+    DEFAULT_NUM_CTX,
     DOMAIN,
     MAX_HISTORY_SECONDS,
 )
@@ -263,6 +265,7 @@ class OllamaConversationEntity(
                 stream=False,
                 # keep_alive requires specifying unit. In this case, seconds
                 keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
+                options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
             )
         except (ollama.RequestError, ollama.ResponseError) as err:
             _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
diff --git a/homeassistant/components/ollama/strings.json b/homeassistant/components/ollama/strings.json
index 2366ecd0848..c307f160228 100644
--- a/homeassistant/components/ollama/strings.json
+++ b/homeassistant/components/ollama/strings.json
@@ -27,11 +27,13 @@
         "prompt": "Instructions",
         "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
         "max_history": "Max history messages",
+        "num_ctx": "Context window size",
         "keep_alive": "Keep alive"
       },
       "data_description": {
         "prompt": "Instruct how the LLM should respond. This can be a template.",
-        "keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never."
+        "keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never.",
+        "num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM usage, or increase for a large number of exposed entities."
       }
     }
   }
diff --git a/tests/components/ollama/conftest.py b/tests/components/ollama/conftest.py
index b28b8850cd5..7658d1cbfab 100644
--- a/tests/components/ollama/conftest.py
+++ b/tests/components/ollama/conftest.py
@@ -1,5 +1,6 @@
 """Tests Ollama integration."""
 
+from typing import Any
 from unittest.mock import patch
 
 import pytest
@@ -16,12 +17,20 @@ from tests.common import MockConfigEntry
 
 
 @pytest.fixture
-def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
+def mock_config_entry_options() -> dict[str, Any]:
+    """Fixture for configuration entry options."""
+    return TEST_OPTIONS
+
+
+@pytest.fixture
+def mock_config_entry(
+    hass: HomeAssistant, mock_config_entry_options: dict[str, Any]
+) -> MockConfigEntry:
     """Mock a config entry."""
     entry = MockConfigEntry(
         domain=ollama.DOMAIN,
         data=TEST_USER_DATA,
-        options=TEST_OPTIONS,
+        options=mock_config_entry_options,
     )
     entry.add_to_hass(hass)
     return entry
diff --git a/tests/components/ollama/test_config_flow.py b/tests/components/ollama/test_config_flow.py
index b1b74197139..7755f2208b4 100644
--- a/tests/components/ollama/test_config_flow.py
+++ b/tests/components/ollama/test_config_flow.py
@@ -164,13 +164,18 @@ async def test_options(
     )
     options = await hass.config_entries.options.async_configure(
         options_flow["flow_id"],
-        {ollama.CONF_PROMPT: "test prompt", ollama.CONF_MAX_HISTORY: 100},
+        {
+            ollama.CONF_PROMPT: "test prompt",
+            ollama.CONF_MAX_HISTORY: 100,
+            ollama.CONF_NUM_CTX: 32768,
+        },
     )
     await hass.async_block_till_done()
     assert options["type"] is FlowResultType.CREATE_ENTRY
     assert options["data"] == {
         ollama.CONF_PROMPT: "test prompt",
         ollama.CONF_MAX_HISTORY: 100,
+        ollama.CONF_NUM_CTX: 32768,
     }
diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py
index cb56b398342..f10805e747d 100644
--- a/tests/components/ollama/test_conversation.py
+++ b/tests/components/ollama/test_conversation.py
@@ -578,3 +578,34 @@ async def test_conversation_agent_with_assist(
         state.attributes[ATTR_SUPPORTED_FEATURES]
         == conversation.ConversationEntityFeature.CONTROL
     )
+
+
+@pytest.mark.parametrize(
+    ("mock_config_entry_options", "expected_options"),
+    [
+        ({}, {"num_ctx": 8192}),
+        ({"num_ctx": 16384}, {"num_ctx": 16384}),
+    ],
+)
+async def test_options(
+    hass: HomeAssistant,
+    mock_config_entry: MockConfigEntry,
+    mock_init_component,
+    expected_options: dict[str, Any],
+) -> None:
+    """Test that options are passed correctly to ollama client."""
+    with patch(
+        "ollama.AsyncClient.chat",
+        return_value={"message": {"role": "assistant", "content": "test response"}},
+    ) as mock_chat:
+        await conversation.async_converse(
+            hass,
+            "test message",
+            None,
+            Context(),
+            agent_id="conversation.mock_title",
+        )
+
+    assert mock_chat.call_count == 1
+    args = mock_chat.call_args.kwargs
+    assert args.get("options") == expected_options
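
Note (not part of the patch): a minimal standalone sketch of what the new option does at the client level, mirroring the conversation.py change above. The model name and server host here are placeholder assumptions; the call itself uses the same ollama Python client API (AsyncClient.chat with an "options" mapping) that the integration invokes.

import asyncio

import ollama


async def main() -> None:
    """Send one chat request with an explicit context window size."""
    # Assumed: a local Ollama server on its default port.
    client = ollama.AsyncClient(host="http://localhost:11434")
    response = await client.chat(
        model="llama3.1",  # placeholder; any locally pulled model
        messages=[{"role": "user", "content": "hello"}],
        stream=False,
        # Mirrors DEFAULT_NUM_CTX above. Without this, Ollama falls back to
        # the model's own (typically much smaller) default context size.
        options={"num_ctx": 8192},
    )
    print(response["message"]["content"])


asyncio.run(main())

Because the options flow constrains the value between MIN_NUM_CTX (2048) and MAX_NUM_CTX (131072) via the BOX-mode NumberSelector, a user-entered value outside that range never reaches this call.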