diff --git a/homeassistant/components/ollama/__init__.py b/homeassistant/components/ollama/__init__.py index c828ee0af9f..90d2012766d 100644 --- a/homeassistant/components/ollama/__init__.py +++ b/homeassistant/components/ollama/__init__.py @@ -8,11 +8,16 @@ import logging import httpx import ollama -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.const import CONF_URL, Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady -from homeassistant.helpers import config_validation as cv +from homeassistant.helpers import ( + config_validation as cv, + device_registry as dr, + entity_registry as er, +) +from homeassistant.helpers.typing import ConfigType from homeassistant.util.ssl import get_default_context from .const import ( @@ -42,8 +47,16 @@ __all__ = [ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) PLATFORMS = (Platform.CONVERSATION,) +type OllamaConfigEntry = ConfigEntry[ollama.AsyncClient] -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + """Set up Ollama.""" + await async_migrate_integration(hass) + return True + + +async def async_setup_entry(hass: HomeAssistant, entry: OllamaConfigEntry) -> bool: """Set up Ollama from a config entry.""" settings = {**entry.data, **entry.options} client = ollama.AsyncClient(host=settings[CONF_URL], verify=get_default_context()) @@ -53,8 +66,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: except (TimeoutError, httpx.ConnectError) as err: raise ConfigEntryNotReady(err) from err - hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client - + entry.runtime_data = client await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) return True @@ -63,5 +75,69 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload Ollama.""" if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS): return False - hass.data[DOMAIN].pop(entry.entry_id) return True + + +async def async_migrate_integration(hass: HomeAssistant) -> None: + """Migrate integration entry structure.""" + + entries = hass.config_entries.async_entries(DOMAIN) + if not any(entry.version == 1 for entry in entries): + return + + api_keys_entries: dict[str, ConfigEntry] = {} + entity_registry = er.async_get(hass) + device_registry = dr.async_get(hass) + + for entry in entries: + use_existing = False + subentry = ConfigSubentry( + data=entry.options, + subentry_type="conversation", + title=entry.title, + unique_id=None, + ) + if entry.data[CONF_URL] not in api_keys_entries: + use_existing = True + api_keys_entries[entry.data[CONF_URL]] = entry + + parent_entry = api_keys_entries[entry.data[CONF_URL]] + + hass.config_entries.async_add_subentry(parent_entry, subentry) + conversation_entity = entity_registry.async_get_entity_id( + "conversation", + DOMAIN, + entry.entry_id, + ) + if conversation_entity is not None: + entity_registry.async_update_entity( + conversation_entity, + config_entry_id=parent_entry.entry_id, + config_subentry_id=subentry.subentry_id, + new_unique_id=subentry.subentry_id, + ) + + device = device_registry.async_get_device( + identifiers={(DOMAIN, entry.entry_id)} + ) + if device is not None: + device_registry.async_update_device( + device.id, + new_identifiers={(DOMAIN, subentry.subentry_id)}, + add_config_subentry_id=subentry.subentry_id, + add_config_entry_id=parent_entry.entry_id, + ) + if parent_entry.entry_id != entry.entry_id: + device_registry.async_update_device( + device.id, + remove_config_entry_id=entry.entry_id, + ) + + if not use_existing: + await hass.config_entries.async_remove(entry.entry_id) + else: + hass.config_entries.async_update_entry( + entry, + options={}, + version=2, + ) diff --git a/homeassistant/components/ollama/config_flow.py b/homeassistant/components/ollama/config_flow.py index b94a0fc621d..58b557549e1 100644 --- a/homeassistant/components/ollama/config_flow.py +++ b/homeassistant/components/ollama/config_flow.py @@ -14,12 +14,14 @@ import voluptuous as vol from homeassistant.config_entries import ( ConfigEntry, + ConfigEntryState, ConfigFlow, ConfigFlowResult, - OptionsFlow, + ConfigSubentryFlow, + SubentryFlowResult, ) -from homeassistant.const import CONF_LLM_HASS_API, CONF_URL -from homeassistant.core import HomeAssistant +from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import llm from homeassistant.helpers.selector import ( BooleanSelector, @@ -70,7 +72,7 @@ STEP_USER_DATA_SCHEMA = vol.Schema( class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): """Handle a config flow for Ollama.""" - VERSION = 1 + VERSION = 2 def __init__(self) -> None: """Initialize config flow.""" @@ -94,6 +96,8 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): errors = {} + self._async_abort_entries_match({CONF_URL: self.url}) + try: self.client = ollama.AsyncClient( host=self.url, verify=get_default_context() @@ -146,8 +150,16 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): return await self.async_step_download() return self.async_create_entry( - title=_get_title(self.model), + title=self.url, data={CONF_URL: self.url, CONF_MODEL: self.model}, + subentries=[ + { + "subentry_type": "conversation", + "data": {}, + "title": _get_title(self.model), + "unique_id": None, + } + ], ) async def async_step_download( @@ -189,6 +201,14 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): return self.async_create_entry( title=_get_title(self.model), data={CONF_URL: self.url, CONF_MODEL: self.model}, + subentries=[ + { + "subentry_type": "conversation", + "data": {}, + "title": _get_title(self.model), + "unique_id": None, + } + ], ) async def async_step_failed( @@ -197,41 +217,62 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): """Step after model downloading has failed.""" return self.async_abort(reason="download_failed") - @staticmethod - def async_get_options_flow( - config_entry: ConfigEntry, - ) -> OptionsFlow: - """Create the options flow.""" - return OllamaOptionsFlow(config_entry) + @classmethod + @callback + def async_get_supported_subentry_types( + cls, config_entry: ConfigEntry + ) -> dict[str, type[ConfigSubentryFlow]]: + """Return subentries supported by this integration.""" + return {"conversation": ConversationSubentryFlowHandler} -class OllamaOptionsFlow(OptionsFlow): - """Ollama options flow.""" +class ConversationSubentryFlowHandler(ConfigSubentryFlow): + """Flow for managing conversation subentries.""" - def __init__(self, config_entry: ConfigEntry) -> None: - """Initialize options flow.""" - self.url: str = config_entry.data[CONF_URL] - self.model: str = config_entry.data[CONF_MODEL] + @property + def _is_new(self) -> bool: + """Return if this is a new subentry.""" + return self.source == "user" - async def async_step_init( + async def async_step_set_options( self, user_input: dict[str, Any] | None = None - ) -> ConfigFlowResult: - """Manage the options.""" - if user_input is not None: + ) -> SubentryFlowResult: + """Set conversation options.""" + # abort if entry is not loaded + if self._get_entry().state != ConfigEntryState.LOADED: + return self.async_abort(reason="entry_not_loaded") + + errors: dict[str, str] = {} + + if user_input is None: + if self._is_new: + options = {} + else: + options = self._get_reconfigure_subentry().data.copy() + + elif self._is_new: return self.async_create_entry( - title=_get_title(self.model), data=user_input + title=user_input.pop(CONF_NAME), + data=user_input, + ) + else: + return self.async_update_and_abort( + self._get_entry(), + self._get_reconfigure_subentry(), + data=user_input, ) - options: Mapping[str, Any] = self.config_entry.options or {} - schema = ollama_config_option_schema(self.hass, options) + schema = ollama_config_option_schema(self.hass, self._is_new, options) return self.async_show_form( - step_id="init", - data_schema=vol.Schema(schema), + step_id="set_options", data_schema=vol.Schema(schema), errors=errors ) + async_step_user = async_step_set_options + async_step_reconfigure = async_step_set_options + def ollama_config_option_schema( - hass: HomeAssistant, options: Mapping[str, Any] + hass: HomeAssistant, is_new: bool, options: Mapping[str, Any] ) -> dict: """Ollama options schema.""" hass_apis: list[SelectOptionDict] = [ @@ -242,54 +283,72 @@ def ollama_config_option_schema( for api in llm.async_get_apis(hass) ] - return { - vol.Optional( - CONF_PROMPT, - description={ - "suggested_value": options.get( - CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT + if is_new: + schema: dict[vol.Required | vol.Optional, Any] = { + vol.Required(CONF_NAME, default="Ollama Conversation"): str, + } + else: + schema = {} + + schema.update( + { + vol.Optional( + CONF_PROMPT, + description={ + "suggested_value": options.get( + CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT + ) + }, + ): TemplateSelector(), + vol.Optional( + CONF_LLM_HASS_API, + description={"suggested_value": options.get(CONF_LLM_HASS_API)}, + ): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)), + vol.Optional( + CONF_NUM_CTX, + description={ + "suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX) + }, + ): NumberSelector( + NumberSelectorConfig( + min=MIN_NUM_CTX, + max=MAX_NUM_CTX, + step=1, + mode=NumberSelectorMode.BOX, ) - }, - ): TemplateSelector(), - vol.Optional( - CONF_LLM_HASS_API, - description={"suggested_value": options.get(CONF_LLM_HASS_API)}, - ): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)), - vol.Optional( - CONF_NUM_CTX, - description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)}, - ): NumberSelector( - NumberSelectorConfig( - min=MIN_NUM_CTX, max=MAX_NUM_CTX, step=1, mode=NumberSelectorMode.BOX - ) - ), - vol.Optional( - CONF_MAX_HISTORY, - description={ - "suggested_value": options.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY) - }, - ): NumberSelector( - NumberSelectorConfig( - min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX - ) - ), - vol.Optional( - CONF_KEEP_ALIVE, - description={ - "suggested_value": options.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE) - }, - ): NumberSelector( - NumberSelectorConfig( - min=-1, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX - ) - ), - vol.Optional( - CONF_THINK, - description={ - "suggested_value": options.get("think", DEFAULT_THINK), - }, - ): BooleanSelector(), - } + ), + vol.Optional( + CONF_MAX_HISTORY, + description={ + "suggested_value": options.get( + CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY + ) + }, + ): NumberSelector( + NumberSelectorConfig( + min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX + ) + ), + vol.Optional( + CONF_KEEP_ALIVE, + description={ + "suggested_value": options.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE) + }, + ): NumberSelector( + NumberSelectorConfig( + min=-1, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX + ) + ), + vol.Optional( + CONF_THINK, + description={ + "suggested_value": options.get("think", DEFAULT_THINK), + }, + ): BooleanSelector(), + } + ) + + return schema def _get_title(model: str) -> str: diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py index 1717d0b24b2..beedb61f942 100644 --- a/homeassistant/components/ollama/conversation.py +++ b/homeassistant/components/ollama/conversation.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import AsyncGenerator, Callable +from collections.abc import AsyncGenerator, AsyncIterator, Callable import json import logging from typing import Any, Literal @@ -11,13 +11,14 @@ import ollama from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation -from homeassistant.config_entries import ConfigEntry +from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import intent, llm +from homeassistant.helpers import device_registry as dr, intent, llm from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback +from . import OllamaConfigEntry from .const import ( CONF_KEEP_ALIVE, CONF_MAX_HISTORY, @@ -40,12 +41,18 @@ _LOGGER = logging.getLogger(__name__) async def async_setup_entry( hass: HomeAssistant, - config_entry: ConfigEntry, + config_entry: OllamaConfigEntry, async_add_entities: AddConfigEntryEntitiesCallback, ) -> None: """Set up conversation entities.""" - agent = OllamaConversationEntity(config_entry) - async_add_entities([agent]) + for subentry in config_entry.subentries.values(): + if subentry.subentry_type != "conversation": + continue + + async_add_entities( + [OllamaConversationEntity(config_entry, subentry)], + config_subentry_id=subentry.subentry_id, + ) def _format_tool( @@ -130,7 +137,7 @@ def _convert_content( async def _transform_stream( - result: AsyncGenerator[ollama.Message], + result: AsyncIterator[ollama.ChatResponse], ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]: """Transform the response stream into HA format. @@ -174,17 +181,22 @@ class OllamaConversationEntity( ): """Ollama conversation agent.""" - _attr_has_entity_name = True _attr_supports_streaming = True - def __init__(self, entry: ConfigEntry) -> None: + def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None: """Initialize the agent.""" self.entry = entry - - # conversation id -> message history - self._attr_name = entry.title - self._attr_unique_id = entry.entry_id - if self.entry.options.get(CONF_LLM_HASS_API): + self.subentry = subentry + self._attr_name = subentry.title + self._attr_unique_id = subentry.subentry_id + self._attr_device_info = dr.DeviceInfo( + identifiers={(DOMAIN, subentry.subentry_id)}, + name=subentry.title, + manufacturer="Ollama", + model=entry.data[CONF_MODEL], + entry_type=dr.DeviceEntryType.SERVICE, + ) + if self.subentry.data.get(CONF_LLM_HASS_API): self._attr_supported_features = ( conversation.ConversationEntityFeature.CONTROL ) @@ -216,7 +228,7 @@ class OllamaConversationEntity( chat_log: conversation.ChatLog, ) -> conversation.ConversationResult: """Call the API.""" - settings = {**self.entry.data, **self.entry.options} + settings = {**self.entry.data, **self.subentry.data} try: await chat_log.async_provide_llm_data( @@ -248,9 +260,9 @@ class OllamaConversationEntity( chat_log: conversation.ChatLog, ) -> None: """Generate an answer for the chat log.""" - settings = {**self.entry.data, **self.entry.options} + settings = {**self.entry.data, **self.subentry.data} - client = self.hass.data[DOMAIN][self.entry.entry_id] + client = self.entry.runtime_data model = settings[CONF_MODEL] tools: list[dict[str, Any]] | None = None diff --git a/homeassistant/components/ollama/strings.json b/homeassistant/components/ollama/strings.json index c60b0ef7ebd..74a5eaff454 100644 --- a/homeassistant/components/ollama/strings.json +++ b/homeassistant/components/ollama/strings.json @@ -12,7 +12,8 @@ } }, "abort": { - "download_failed": "Model downloading failed" + "download_failed": "Model downloading failed", + "already_configured": "[%key:common::config_flow::abort::already_configured_service%]" }, "error": { "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", @@ -22,23 +23,35 @@ "download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details." } }, - "options": { - "step": { - "init": { - "data": { - "prompt": "Instructions", - "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", - "max_history": "Max history messages", - "num_ctx": "Context window size", - "keep_alive": "Keep alive", - "think": "Think before responding" - }, - "data_description": { - "prompt": "Instruct how the LLM should respond. This can be a template.", - "keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never.", - "num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.", - "think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency." + "config_subentries": { + "conversation": { + "initiate_flow": { + "user": "Add conversation agent", + "reconfigure": "Reconfigure conversation agent" + }, + "entry_type": "Conversation agent", + "step": { + "set_options": { + "data": { + "name": "[%key:common::config_flow::data::name%]", + "prompt": "Instructions", + "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", + "max_history": "Max history messages", + "num_ctx": "Context window size", + "keep_alive": "Keep alive", + "think": "Think before responding" + }, + "data_description": { + "prompt": "Instruct how the LLM should respond. This can be a template.", + "keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never.", + "num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.", + "think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency." + } } + }, + "abort": { + "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]", + "entry_not_loaded": "Cannot add things while the configuration is disabled." } } } diff --git a/tests/components/ollama/conftest.py b/tests/components/ollama/conftest.py index 7658d1cbfab..c99f586a5d4 100644 --- a/tests/components/ollama/conftest.py +++ b/tests/components/ollama/conftest.py @@ -30,7 +30,15 @@ def mock_config_entry( entry = MockConfigEntry( domain=ollama.DOMAIN, data=TEST_USER_DATA, - options=mock_config_entry_options, + version=2, + subentries_data=[ + { + "data": mock_config_entry_options, + "subentry_type": "conversation", + "title": "Ollama Conversation", + "unique_id": None, + } + ], ) entry.add_to_hass(hass) return entry @@ -41,8 +49,10 @@ def mock_config_entry_with_assist( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> MockConfigEntry: """Mock a config entry with assist.""" - hass.config_entries.async_update_entry( - mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST} + hass.config_entries.async_update_subentry( + mock_config_entry, + next(iter(mock_config_entry.subentries.values())), + data={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}, ) return mock_config_entry diff --git a/tests/components/ollama/test_config_flow.py b/tests/components/ollama/test_config_flow.py index 34282f25e90..4b78df9bce2 100644 --- a/tests/components/ollama/test_config_flow.py +++ b/tests/components/ollama/test_config_flow.py @@ -63,6 +63,37 @@ async def test_form(hass: HomeAssistant) -> None: assert len(mock_setup_entry.mock_calls) == 1 +async def test_duplicate_entry(hass: HomeAssistant) -> None: + """Test we abort on duplicate config entry.""" + MockConfigEntry( + domain=ollama.DOMAIN, + data={ + ollama.CONF_URL: "http://localhost:11434", + ollama.CONF_MODEL: "test_model", + }, + ).add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + ollama.DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + assert not result["errors"] + + with patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", + return_value={"models": [{"model": "test_model"}]}, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + ollama.CONF_URL: "http://localhost:11434", + }, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "already_configured" + + async def test_form_need_download(hass: HomeAssistant) -> None: """Test flow when a model needs to be downloaded.""" # Pretend we already set up a config entry. @@ -155,14 +186,21 @@ async def test_form_need_download(hass: HomeAssistant) -> None: assert len(mock_setup_entry.mock_calls) == 1 -async def test_options( +async def test_subentry_options( hass: HomeAssistant, mock_config_entry, mock_init_component ) -> None: - """Test the options form.""" - options_flow = await hass.config_entries.options.async_init( - mock_config_entry.entry_id + """Test the subentry options form.""" + subentry = next(iter(mock_config_entry.subentries.values())) + + # Test reconfiguration + options_flow = await mock_config_entry.start_subentry_reconfigure_flow( + hass, subentry.subentry_id ) - options = await hass.config_entries.options.async_configure( + + assert options_flow["type"] is FlowResultType.FORM + assert options_flow["step_id"] == "set_options" + + options = await hass.config_entries.subentries.async_configure( options_flow["flow_id"], { ollama.CONF_PROMPT: "test prompt", @@ -172,8 +210,10 @@ async def test_options( }, ) await hass.async_block_till_done() - assert options["type"] is FlowResultType.CREATE_ENTRY - assert options["data"] == { + + assert options["type"] is FlowResultType.ABORT + assert options["reason"] == "reconfigure_successful" + assert subentry.data == { ollama.CONF_PROMPT: "test prompt", ollama.CONF_MAX_HISTORY: 100, ollama.CONF_NUM_CTX: 32768, @@ -181,6 +221,22 @@ async def test_options( } +async def test_creating_conversation_subentry_not_loaded( + hass: HomeAssistant, + mock_init_component, + mock_config_entry: MockConfigEntry, +) -> None: + """Test creating a conversation subentry when entry is not loaded.""" + await hass.config_entries.async_unload(mock_config_entry.entry_id) + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, "conversation"), + context={"source": config_entries.SOURCE_USER}, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "entry_not_loaded" + + @pytest.mark.parametrize( ("side_effect", "error"), [ diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py index e83c2a3495f..cebb185bd08 100644 --- a/tests/components/ollama/test_conversation.py +++ b/tests/components/ollama/test_conversation.py @@ -35,7 +35,7 @@ async def stream_generator(response: dict | list[dict]) -> AsyncGenerator[dict]: yield msg -@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"]) +@pytest.mark.parametrize("agent_id", [None, "conversation.ollama_conversation"]) async def test_chat( hass: HomeAssistant, mock_config_entry: MockConfigEntry, @@ -149,9 +149,11 @@ async def test_template_variables( mock_user.id = "12345" mock_user.name = "Test User" - hass.config_entries.async_update_entry( + subentry = next(iter(mock_config_entry.subentries.values())) + hass.config_entries.async_update_subentry( mock_config_entry, - options={ + subentry, + data={ "prompt": ( "The user name is {{ user_name }}. " "The user id is {{ llm_context.context.user_id }}." @@ -382,10 +384,12 @@ async def test_unknown_hass_api( mock_init_component, ) -> None: """Test when we reference an API that no longer exists.""" - hass.config_entries.async_update_entry( + subentry = next(iter(mock_config_entry.subentries.values())) + hass.config_entries.async_update_subentry( mock_config_entry, - options={ - **mock_config_entry.options, + subentry, + data={ + **subentry.data, CONF_LLM_HASS_API: "non-existing", }, ) @@ -518,8 +522,9 @@ async def test_message_history_unlimited( with ( patch("ollama.AsyncClient.chat", side_effect=stream) as mock_chat, ): - hass.config_entries.async_update_entry( - mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0} + subentry = next(iter(mock_config_entry.subentries.values())) + hass.config_entries.async_update_subentry( + mock_config_entry, subentry, data={ollama.CONF_MAX_HISTORY: 0} ) for i in range(100): result = await conversation.async_converse( @@ -563,9 +568,11 @@ async def test_template_error( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> None: """Test that template error handling works.""" - hass.config_entries.async_update_entry( + subentry = next(iter(mock_config_entry.subentries.values())) + hass.config_entries.async_update_subentry( mock_config_entry, - options={ + subentry, + data={ "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", }, ) @@ -593,7 +600,7 @@ async def test_conversation_agent( ) assert agent.supported_languages == MATCH_ALL - state = hass.states.get("conversation.mock_title") + state = hass.states.get("conversation.ollama_conversation") assert state assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0 @@ -609,7 +616,7 @@ async def test_conversation_agent_with_assist( ) assert agent.supported_languages == MATCH_ALL - state = hass.states.get("conversation.mock_title") + state = hass.states.get("conversation.ollama_conversation") assert state assert ( state.attributes[ATTR_SUPPORTED_FEATURES] @@ -642,7 +649,7 @@ async def test_options( "test message", None, Context(), - agent_id="conversation.mock_title", + agent_id="conversation.ollama_conversation", ) assert mock_chat.call_count == 1 @@ -667,9 +674,11 @@ async def test_reasoning_filter( entry = MockConfigEntry() entry.add_to_hass(hass) - hass.config_entries.async_update_entry( + subentry = next(iter(mock_config_entry.subentries.values())) + hass.config_entries.async_update_subentry( mock_config_entry, - options={ + subentry, + data={ ollama.CONF_THINK: think, }, ) diff --git a/tests/components/ollama/test_init.py b/tests/components/ollama/test_init.py index d1074226837..e11eb98451a 100644 --- a/tests/components/ollama/test_init.py +++ b/tests/components/ollama/test_init.py @@ -6,9 +6,13 @@ from httpx import ConnectError import pytest from homeassistant.components import ollama +from homeassistant.components.ollama.const import DOMAIN from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.setup import async_setup_component +from . import TEST_OPTIONS, TEST_USER_DATA + from tests.common import MockConfigEntry @@ -34,3 +38,250 @@ async def test_init_error( assert await async_setup_component(hass, ollama.DOMAIN, {}) await hass.async_block_till_done() assert error in caplog.text + + +async def test_migration_from_v1_to_v2( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test migration from version 1 to version 2.""" + # Create a v1 config entry with conversation options and an entity + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data=TEST_USER_DATA, + options=TEST_OPTIONS, + version=1, + title="llama-3.2-8b", + ) + mock_config_entry.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="Ollama", + model="Ollama", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity = entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry.entry_id, + config_entry=mock_config_entry, + device_id=device.id, + suggested_object_id="llama_3_2_8b", + ) + + # Run migration + with patch( + "homeassistant.components.ollama.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + + assert mock_config_entry.version == 2 + assert mock_config_entry.data == TEST_USER_DATA + assert mock_config_entry.options == {} + + assert len(mock_config_entry.subentries) == 1 + + subentry = next(iter(mock_config_entry.subentries.values())) + assert subentry.unique_id is None + assert subentry.title == "llama-3.2-8b" + assert subentry.subentry_type == "conversation" + assert subentry.data == TEST_OPTIONS + + migrated_entity = entity_registry.async_get(entity.entity_id) + assert migrated_entity is not None + assert migrated_entity.config_entry_id == mock_config_entry.entry_id + assert migrated_entity.config_subentry_id == subentry.subentry_id + assert migrated_entity.unique_id == subentry.subentry_id + + # Check device migration + assert not device_registry.async_get_device( + identifiers={(DOMAIN, mock_config_entry.entry_id)} + ) + assert ( + migrated_device := device_registry.async_get_device( + identifiers={(DOMAIN, subentry.subentry_id)} + ) + ) + assert migrated_device.identifiers == {(DOMAIN, subentry.subentry_id)} + assert migrated_device.id == device.id + + +async def test_migration_from_v1_to_v2_with_multiple_urls( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test migration from version 1 to version 2 with different URLs.""" + # Create two v1 config entries with different URLs + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, + options=TEST_OPTIONS, + version=1, + title="Ollama 1", + ) + mock_config_entry.add_to_hass(hass) + mock_config_entry_2 = MockConfigEntry( + domain=DOMAIN, + data={"url": "http://localhost:11435", "model": "llama3.2:latest"}, + options=TEST_OPTIONS, + version=1, + title="Ollama 2", + ) + mock_config_entry_2.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="Ollama", + model="Ollama 1", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry.entry_id, + config_entry=mock_config_entry, + device_id=device.id, + suggested_object_id="ollama_1", + ) + + device_2 = device_registry.async_get_or_create( + config_entry_id=mock_config_entry_2.entry_id, + identifiers={(DOMAIN, mock_config_entry_2.entry_id)}, + name=mock_config_entry_2.title, + manufacturer="Ollama", + model="Ollama 2", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry_2.entry_id, + config_entry=mock_config_entry_2, + device_id=device_2.id, + suggested_object_id="ollama_2", + ) + + # Run migration + with patch( + "homeassistant.components.ollama.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 2 + + for idx, entry in enumerate(entries): + assert entry.version == 2 + assert not entry.options + assert len(entry.subentries) == 1 + subentry = list(entry.subentries.values())[0] + assert subentry.subentry_type == "conversation" + assert subentry.data == TEST_OPTIONS + assert subentry.title == f"Ollama {idx + 1}" + + dev = device_registry.async_get_device( + identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)} + ) + assert dev is not None + + +async def test_migration_from_v1_to_v2_with_same_urls( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test migration from version 1 to version 2 with same URLs consolidates entries.""" + # Create two v1 config entries with the same URL + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, + options=TEST_OPTIONS, + version=1, + title="Ollama", + ) + mock_config_entry.add_to_hass(hass) + mock_config_entry_2 = MockConfigEntry( + domain=DOMAIN, + data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, # Same URL + options=TEST_OPTIONS, + version=1, + title="Ollama 2", + ) + mock_config_entry_2.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="Ollama", + model="Ollama", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry.entry_id, + config_entry=mock_config_entry, + device_id=device.id, + suggested_object_id="ollama", + ) + + device_2 = device_registry.async_get_or_create( + config_entry_id=mock_config_entry_2.entry_id, + identifiers={(DOMAIN, mock_config_entry_2.entry_id)}, + name=mock_config_entry_2.title, + manufacturer="Ollama", + model="Ollama", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry_2.entry_id, + config_entry=mock_config_entry_2, + device_id=device_2.id, + suggested_object_id="ollama_2", + ) + + # Run migration + with patch( + "homeassistant.components.ollama.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + # Should have only one entry left (consolidated) + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + + entry = entries[0] + assert entry.version == 2 + assert not entry.options + assert len(entry.subentries) == 2 # Two subentries from the two original entries + + # Check both subentries exist with correct data + subentries = list(entry.subentries.values()) + titles = [sub.title for sub in subentries] + assert "Ollama" in titles + assert "Ollama 2" in titles + + for subentry in subentries: + assert subentry.subentry_type == "conversation" + assert subentry.data == TEST_OPTIONS + + # Check devices were migrated correctly + dev = device_registry.async_get_device( + identifiers={(DOMAIN, subentry.subentry_id)} + ) + assert dev is not None