Ollama: Migrate pick model to subentry (#147944)

This commit is contained in:
Paulus Schoutsen 2025-07-02 15:20:42 +02:00 committed by GitHub
parent 943fb9948b
commit f50ef79c72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 655 additions and 333 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from types import MappingProxyType
import httpx import httpx
import ollama import ollama
@ -100,8 +101,12 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
for entry in entries: for entry in entries:
use_existing = False use_existing = False
# Create subentry with model from entry.data and options from entry.options
subentry_data = entry.options.copy()
subentry_data[CONF_MODEL] = entry.data[CONF_MODEL]
subentry = ConfigSubentry( subentry = ConfigSubentry(
data=entry.options, data=MappingProxyType(subentry_data),
subentry_type="conversation", subentry_type="conversation",
title=entry.title, title=entry.title,
unique_id=None, unique_id=None,
@ -154,9 +159,11 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
hass.config_entries.async_update_entry( hass.config_entries.async_update_entry(
entry, entry,
title=DEFAULT_NAME, title=DEFAULT_NAME,
# Update parent entry to only keep URL, remove model
data={CONF_URL: entry.data[CONF_URL]},
options={}, options={},
version=2, version=3,
minor_version=2, minor_version=1,
) )
@ -164,7 +171,7 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OllamaConfigEntry) ->
"""Migrate entry.""" """Migrate entry."""
_LOGGER.debug("Migrating from version %s:%s", entry.version, entry.minor_version) _LOGGER.debug("Migrating from version %s:%s", entry.version, entry.minor_version)
if entry.version > 2: if entry.version > 3:
# This means the user has downgraded from a future version # This means the user has downgraded from a future version
return False return False
@ -182,6 +189,25 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OllamaConfigEntry) ->
hass.config_entries.async_update_entry(entry, minor_version=2) hass.config_entries.async_update_entry(entry, minor_version=2)
if entry.version == 2 and entry.minor_version == 2:
# Update subentries to include the model
for subentry in entry.subentries.values():
if subentry.subentry_type == "conversation":
updated_data = dict(subentry.data)
updated_data[CONF_MODEL] = entry.data[CONF_MODEL]
hass.config_entries.async_update_subentry(
entry, subentry, data=MappingProxyType(updated_data)
)
# Update main entry to remove model and bump version
hass.config_entries.async_update_entry(
entry,
data={CONF_URL: entry.data[CONF_URL]},
version=3,
minor_version=1,
)
_LOGGER.debug( _LOGGER.debug(
"Migration to version %s:%s successful", entry.version, entry.minor_version "Migration to version %s:%s successful", entry.version, entry.minor_version
) )

View File

@ -22,7 +22,7 @@ from homeassistant.config_entries import (
) )
from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import llm from homeassistant.helpers import config_validation as cv, llm
from homeassistant.helpers.selector import ( from homeassistant.helpers.selector import (
BooleanSelector, BooleanSelector,
NumberSelector, NumberSelector,
@ -38,6 +38,7 @@ from homeassistant.helpers.selector import (
) )
from homeassistant.util.ssl import get_default_context from homeassistant.util.ssl import get_default_context
from . import OllamaConfigEntry
from .const import ( from .const import (
CONF_KEEP_ALIVE, CONF_KEEP_ALIVE,
CONF_MAX_HISTORY, CONF_MAX_HISTORY,
@ -72,43 +73,43 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Ollama.""" """Handle a config flow for Ollama."""
VERSION = 2 VERSION = 3
MINOR_VERSION = 2 MINOR_VERSION = 1
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize config flow.""" """Initialize config flow."""
self.url: str | None = None self.url: str | None = None
self.model: str | None = None
self.client: ollama.AsyncClient | None = None
self.download_task: asyncio.Task | None = None
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Handle the initial step.""" """Handle the initial step."""
user_input = user_input or {} if user_input is None:
self.url = user_input.get(CONF_URL, self.url)
self.model = user_input.get(CONF_MODEL, self.model)
if self.url is None:
return self.async_show_form( return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, last_step=False step_id="user", data_schema=STEP_USER_DATA_SCHEMA
) )
errors = {} errors = {}
url = user_input[CONF_URL]
self._async_abort_entries_match({CONF_URL: self.url}) self._async_abort_entries_match({CONF_URL: url})
try: try:
self.client = ollama.AsyncClient( url = cv.url(url)
host=self.url, verify=get_default_context() except vol.Invalid:
errors["base"] = "invalid_url"
return self.async_show_form(
step_id="user",
data_schema=self.add_suggested_values_to_schema(
STEP_USER_DATA_SCHEMA, user_input
),
errors=errors,
) )
async with asyncio.timeout(DEFAULT_TIMEOUT):
response = await self.client.list()
downloaded_models: set[str] = { try:
model_info["model"] for model_info in response.get("models", []) client = ollama.AsyncClient(host=url, verify=get_default_context())
} async with asyncio.timeout(DEFAULT_TIMEOUT):
await client.list()
except (TimeoutError, httpx.ConnectError): except (TimeoutError, httpx.ConnectError):
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
except Exception: except Exception:
@ -117,10 +118,69 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
if errors: if errors:
return self.async_show_form( return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors step_id="user",
data_schema=self.add_suggested_values_to_schema(
STEP_USER_DATA_SCHEMA, user_input
),
errors=errors,
) )
if self.model is None: return self.async_create_entry(
title=url,
data={CONF_URL: url},
)
@classmethod
@callback
def async_get_supported_subentry_types(
cls, config_entry: ConfigEntry
) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this integration."""
return {"conversation": ConversationSubentryFlowHandler}
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
"""Flow for managing conversation subentries."""
def __init__(self) -> None:
"""Initialize the subentry flow."""
super().__init__()
self._name: str | None = None
self._model: str | None = None
self.download_task: asyncio.Task | None = None
self._config_data: dict[str, Any] | None = None
@property
def _is_new(self) -> bool:
"""Return if this is a new subentry."""
return self.source == "user"
@property
def _client(self) -> ollama.AsyncClient:
"""Return the Ollama client."""
entry: OllamaConfigEntry = self._get_entry()
return entry.runtime_data
async def async_step_set_options(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Handle model selection and configuration step."""
if self._get_entry().state != ConfigEntryState.LOADED:
return self.async_abort(reason="entry_not_loaded")
if user_input is None:
# Get available models from Ollama server
try:
async with asyncio.timeout(DEFAULT_TIMEOUT):
response = await self._client.list()
downloaded_models: set[str] = {
model_info["model"] for model_info in response.get("models", [])
}
except (TimeoutError, httpx.ConnectError, httpx.HTTPError):
_LOGGER.exception("Failed to get models from Ollama server")
return self.async_abort(reason="cannot_connect")
# Show models that have been downloaded first, followed by all known # Show models that have been downloaded first, followed by all known
# models (only latest tags). # models (only latest tags).
models_to_list = [ models_to_list = [
@ -131,52 +191,69 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
for m in sorted(MODEL_NAMES) for m in sorted(MODEL_NAMES)
if m not in downloaded_models if m not in downloaded_models
] ]
model_step_schema = vol.Schema(
{ if self._is_new:
vol.Required( options = {}
CONF_MODEL, description={"suggested_value": DEFAULT_MODEL} else:
): SelectSelector( options = self._get_reconfigure_subentry().data.copy()
SelectSelectorConfig(options=models_to_list, custom_value=True)
),
}
)
return self.async_show_form( return self.async_show_form(
step_id="user", step_id="set_options",
data_schema=model_step_schema, data_schema=vol.Schema(
ollama_config_option_schema(
self.hass, self._is_new, options, models_to_list
)
),
) )
if self.model not in downloaded_models: self._model = user_input[CONF_MODEL]
# Ollama server needs to download model first if self._is_new:
return await self.async_step_download() self._name = user_input.pop(CONF_NAME)
return self.async_create_entry( # Check if model needs to be downloaded
title=self.url, try:
data={CONF_URL: self.url, CONF_MODEL: self.model}, async with asyncio.timeout(DEFAULT_TIMEOUT):
subentries=[ response = await self._client.list()
{
"subentry_type": "conversation", currently_downloaded_models: set[str] = {
"data": {}, model_info["model"] for model_info in response.get("models", [])
"title": _get_title(self.model), }
"unique_id": None,
} if self._model not in currently_downloaded_models:
], # Store the user input to use after download
self._config_data = user_input
# Ollama server needs to download model first
return await self.async_step_download()
except Exception:
_LOGGER.exception("Failed to check model availability")
return self.async_abort(reason="cannot_connect")
# Model is already downloaded, create/update the entry
if self._is_new:
return self.async_create_entry(
title=self._name,
data=user_input,
)
return self.async_update_and_abort(
self._get_entry(),
self._get_reconfigure_subentry(),
data=user_input,
) )
async def async_step_download( async def async_step_download(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult: ) -> SubentryFlowResult:
"""Step to wait for Ollama server to download a model.""" """Step to wait for Ollama server to download a model."""
assert self.model is not None assert self._model is not None
assert self.client is not None
if self.download_task is None: if self.download_task is None:
# Tell Ollama server to pull the model. # Tell Ollama server to pull the model.
# The task will block until the model and metadata are fully # The task will block until the model and metadata are fully
# downloaded. # downloaded.
self.download_task = self.hass.async_create_background_task( self.download_task = self.hass.async_create_background_task(
self.client.pull(self.model), self._client.pull(self._model),
f"Downloading {self.model}", f"Downloading {self._model}",
) )
if self.download_task.done(): if self.download_task.done():
@ -192,80 +269,28 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
progress_task=self.download_task, progress_task=self.download_task,
) )
async def async_step_finish(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Step after model downloading has succeeded."""
assert self.url is not None
assert self.model is not None
return self.async_create_entry(
title=_get_title(self.model),
data={CONF_URL: self.url, CONF_MODEL: self.model},
subentries=[
{
"subentry_type": "conversation",
"data": {},
"title": _get_title(self.model),
"unique_id": None,
}
],
)
async def async_step_failed( async def async_step_failed(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult: ) -> SubentryFlowResult:
"""Step after model downloading has failed.""" """Step after model downloading has failed."""
return self.async_abort(reason="download_failed") return self.async_abort(reason="download_failed")
@classmethod async def async_step_finish(
@callback
def async_get_supported_subentry_types(
cls, config_entry: ConfigEntry
) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this integration."""
return {"conversation": ConversationSubentryFlowHandler}
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
"""Flow for managing conversation subentries."""
@property
def _is_new(self) -> bool:
"""Return if this is a new subentry."""
return self.source == "user"
async def async_step_set_options(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult: ) -> SubentryFlowResult:
"""Set conversation options.""" """Step after model downloading has succeeded."""
# abort if entry is not loaded assert self._config_data is not None
if self._get_entry().state != ConfigEntryState.LOADED:
return self.async_abort(reason="entry_not_loaded")
errors: dict[str, str] = {} # Model download completed, create/update the entry with stored config
if self._is_new:
if user_input is None:
if self._is_new:
options = {}
else:
options = self._get_reconfigure_subentry().data.copy()
elif self._is_new:
return self.async_create_entry( return self.async_create_entry(
title=user_input.pop(CONF_NAME), title=self._name,
data=user_input, data=self._config_data,
) )
else: return self.async_update_and_abort(
return self.async_update_and_abort( self._get_entry(),
self._get_entry(), self._get_reconfigure_subentry(),
self._get_reconfigure_subentry(), data=self._config_data,
data=user_input,
)
schema = ollama_config_option_schema(self.hass, self._is_new, options)
return self.async_show_form(
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
) )
async_step_user = async_step_set_options async_step_user = async_step_set_options
@ -273,19 +298,14 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
def ollama_config_option_schema( def ollama_config_option_schema(
hass: HomeAssistant, is_new: bool, options: Mapping[str, Any] hass: HomeAssistant,
is_new: bool,
options: Mapping[str, Any],
models_to_list: list[SelectOptionDict],
) -> dict: ) -> dict:
"""Ollama options schema.""" """Ollama options schema."""
hass_apis: list[SelectOptionDict] = [
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
]
if is_new: if is_new:
schema: dict[vol.Required | vol.Optional, Any] = { schema: dict = {
vol.Required(CONF_NAME, default="Ollama Conversation"): str, vol.Required(CONF_NAME, default="Ollama Conversation"): str,
} }
else: else:
@ -293,6 +313,12 @@ def ollama_config_option_schema(
schema.update( schema.update(
{ {
vol.Required(
CONF_MODEL,
description={"suggested_value": options.get(CONF_MODEL, DEFAULT_MODEL)},
): SelectSelector(
SelectSelectorConfig(options=models_to_list, custom_value=True)
),
vol.Optional( vol.Optional(
CONF_PROMPT, CONF_PROMPT,
description={ description={
@ -304,7 +330,18 @@ def ollama_config_option_schema(
vol.Optional( vol.Optional(
CONF_LLM_HASS_API, CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)}, description={"suggested_value": options.get(CONF_LLM_HASS_API)},
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)), ): SelectSelector(
SelectSelectorConfig(
options=[
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
],
multiple=True,
)
),
vol.Optional( vol.Optional(
CONF_NUM_CTX, CONF_NUM_CTX,
description={ description={
@ -350,11 +387,3 @@ def ollama_config_option_schema(
) )
return schema return schema
def _get_title(model: str) -> str:
"""Get title for config entry."""
if model.endswith(":latest"):
model = model.split(":", maxsplit=1)[0]
return model

View File

@ -166,11 +166,14 @@ class OllamaBaseLLMEntity(Entity):
self.subentry = subentry self.subentry = subentry
self._attr_name = subentry.title self._attr_name = subentry.title
self._attr_unique_id = subentry.subentry_id self._attr_unique_id = subentry.subentry_id
model, _, version = subentry.data[CONF_MODEL].partition(":")
self._attr_device_info = dr.DeviceInfo( self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, subentry.subentry_id)}, identifiers={(DOMAIN, subentry.subentry_id)},
name=subentry.title, name=subentry.title,
manufacturer="Ollama", manufacturer="Ollama",
model=entry.data[CONF_MODEL], model=model,
sw_version=version or "latest",
entry_type=dr.DeviceEntryType.SERVICE, entry_type=dr.DeviceEntryType.SERVICE,
) )

View File

@ -3,24 +3,17 @@
"step": { "step": {
"user": { "user": {
"data": { "data": {
"url": "[%key:common::config_flow::data::url%]", "url": "[%key:common::config_flow::data::url%]"
"model": "Model"
} }
},
"download": {
"title": "Downloading model"
} }
}, },
"abort": { "abort": {
"download_failed": "Model downloading failed",
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]" "already_configured": "[%key:common::config_flow::abort::already_configured_service%]"
}, },
"error": { "error": {
"invalid_url": "[%key:common::config_flow::error::invalid_host%]",
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
"unknown": "[%key:common::config_flow::error::unknown%]" "unknown": "[%key:common::config_flow::error::unknown%]"
},
"progress": {
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
} }
}, },
"config_subentries": { "config_subentries": {
@ -33,6 +26,7 @@
"step": { "step": {
"set_options": { "set_options": {
"data": { "data": {
"model": "Model",
"name": "[%key:common::config_flow::data::name%]", "name": "[%key:common::config_flow::data::name%]",
"prompt": "Instructions", "prompt": "Instructions",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
@ -47,11 +41,19 @@
"num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.", "num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.",
"think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency." "think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency."
} }
},
"download": {
"title": "Downloading model"
} }
}, },
"abort": { "abort": {
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]", "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
"entry_not_loaded": "Cannot add things while the configuration is disabled." "entry_not_loaded": "Failed to add agent. The configuration is disabled.",
"download_failed": "Model downloading failed",
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
},
"progress": {
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
} }
} }
} }

View File

@ -5,10 +5,10 @@ from homeassistant.helpers import llm
TEST_USER_DATA = { TEST_USER_DATA = {
ollama.CONF_URL: "http://localhost:11434", ollama.CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: "test model",
} }
TEST_OPTIONS = { TEST_OPTIONS = {
ollama.CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT, ollama.CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
ollama.CONF_MAX_HISTORY: 2, ollama.CONF_MAX_HISTORY: 2,
ollama.CONF_MODEL: "test_model:latest",
} }

View File

@ -30,10 +30,11 @@ def mock_config_entry(
entry = MockConfigEntry( entry = MockConfigEntry(
domain=ollama.DOMAIN, domain=ollama.DOMAIN,
data=TEST_USER_DATA, data=TEST_USER_DATA,
version=2, version=3,
minor_version=1,
subentries_data=[ subentries_data=[
{ {
"data": mock_config_entry_options, "data": {**TEST_OPTIONS, **mock_config_entry_options},
"subentry_type": "conversation", "subentry_type": "conversation",
"title": "Ollama Conversation", "title": "Ollama Conversation",
"unique_id": None, "unique_id": None,
@ -49,10 +50,14 @@ def mock_config_entry_with_assist(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> MockConfigEntry: ) -> MockConfigEntry:
"""Mock a config entry with assist.""" """Mock a config entry with assist."""
subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry( hass.config_entries.async_update_subentry(
mock_config_entry, mock_config_entry,
next(iter(mock_config_entry.subentries.values())), subentry,
data={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}, data={
**subentry.data,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
},
) )
return mock_config_entry return mock_config_entry

View File

@ -8,6 +8,7 @@ import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import ollama from homeassistant.components import ollama
from homeassistant.const import CONF_NAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
@ -17,7 +18,7 @@ TEST_MODEL = "test_model:latest"
async def test_form(hass: HomeAssistant) -> None: async def test_form(hass: HomeAssistant) -> None:
"""Test flow when the model is already downloaded.""" """Test flow when configuring URL only."""
# Pretend we already set up a config entry. # Pretend we already set up a config entry.
hass.config.components.add(ollama.DOMAIN) hass.config.components.add(ollama.DOMAIN)
MockConfigEntry( MockConfigEntry(
@ -34,7 +35,6 @@ async def test_form(hass: HomeAssistant) -> None:
with ( with (
patch( patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
# test model is already "downloaded"
return_value={"models": [{"model": TEST_MODEL}]}, return_value={"models": [{"model": TEST_MODEL}]},
), ),
patch( patch(
@ -42,24 +42,17 @@ async def test_form(hass: HomeAssistant) -> None:
return_value=True, return_value=True,
) as mock_setup_entry, ) as mock_setup_entry,
): ):
# Step 1: URL
result2 = await hass.config_entries.flow.async_configure( result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"} result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
) )
await hass.async_block_till_done() await hass.async_block_till_done()
# Step 2: model assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["type"] is FlowResultType.FORM assert result2["data"] == {
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
)
await hass.async_block_till_done()
assert result3["type"] is FlowResultType.CREATE_ENTRY
assert result3["data"] == {
ollama.CONF_URL: "http://localhost:11434", ollama.CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: TEST_MODEL,
} }
# No subentries created by default
assert len(result2.get("subentries", [])) == 0
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
@ -94,98 +87,6 @@ async def test_duplicate_entry(hass: HomeAssistant) -> None:
assert result["reason"] == "already_configured" assert result["reason"] == "already_configured"
async def test_form_need_download(hass: HomeAssistant) -> None:
"""Test flow when a model needs to be downloaded."""
# Pretend we already set up a config entry.
hass.config.components.add(ollama.DOMAIN)
MockConfigEntry(
domain=ollama.DOMAIN,
state=config_entries.ConfigEntryState.LOADED,
).add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] is None
pull_ready = asyncio.Event()
pull_called = asyncio.Event()
pull_model: str | None = None
async def pull(self, model: str, *args, **kwargs) -> None:
nonlocal pull_model
async with asyncio.timeout(1):
await pull_ready.wait()
pull_model = model
pull_called.set()
with (
patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
# No models are downloaded
return_value={},
),
patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull",
pull,
),
patch(
"homeassistant.components.ollama.async_setup_entry",
return_value=True,
) as mock_setup_entry,
):
# Step 1: URL
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
)
await hass.async_block_till_done()
# Step 2: model
assert result2["type"] is FlowResultType.FORM
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
)
await hass.async_block_till_done()
# Step 3: download
assert result3["type"] is FlowResultType.SHOW_PROGRESS
result4 = await hass.config_entries.flow.async_configure(
result3["flow_id"],
)
await hass.async_block_till_done()
# Run again without the task finishing.
# We should still be downloading.
assert result4["type"] is FlowResultType.SHOW_PROGRESS
result4 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
)
await hass.async_block_till_done()
assert result4["type"] is FlowResultType.SHOW_PROGRESS
# Signal fake pull method to complete
pull_ready.set()
async with asyncio.timeout(1):
await pull_called.wait()
assert pull_model == TEST_MODEL
# Step 4: finish
result5 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
)
assert result5["type"] is FlowResultType.CREATE_ENTRY
assert result5["data"] == {
ollama.CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: TEST_MODEL,
}
assert len(mock_setup_entry.mock_calls) == 1
async def test_subentry_options( async def test_subentry_options(
hass: HomeAssistant, mock_config_entry, mock_init_component hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None: ) -> None:
@ -193,34 +94,84 @@ async def test_subentry_options(
subentry = next(iter(mock_config_entry.subentries.values())) subentry = next(iter(mock_config_entry.subentries.values()))
# Test reconfiguration # Test reconfiguration
options_flow = await mock_config_entry.start_subentry_reconfigure_flow( with patch(
hass, subentry.subentry_id "ollama.AsyncClient.list",
) return_value={"models": [{"model": TEST_MODEL}]},
):
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
hass, subentry.subentry_id
)
assert options_flow["type"] is FlowResultType.FORM assert options_flow["type"] is FlowResultType.FORM
assert options_flow["step_id"] == "set_options" assert options_flow["step_id"] == "set_options"
options = await hass.config_entries.subentries.async_configure( options = await hass.config_entries.subentries.async_configure(
options_flow["flow_id"], options_flow["flow_id"],
{ {
ollama.CONF_PROMPT: "test prompt", ollama.CONF_MODEL: TEST_MODEL,
ollama.CONF_MAX_HISTORY: 100, ollama.CONF_PROMPT: "test prompt",
ollama.CONF_NUM_CTX: 32768, ollama.CONF_MAX_HISTORY: 100,
ollama.CONF_THINK: True, ollama.CONF_NUM_CTX: 32768,
}, ollama.CONF_THINK: True,
) },
)
await hass.async_block_till_done() await hass.async_block_till_done()
assert options["type"] is FlowResultType.ABORT assert options["type"] is FlowResultType.ABORT
assert options["reason"] == "reconfigure_successful" assert options["reason"] == "reconfigure_successful"
assert subentry.data == { assert subentry.data == {
ollama.CONF_MODEL: TEST_MODEL,
ollama.CONF_PROMPT: "test prompt", ollama.CONF_PROMPT: "test prompt",
ollama.CONF_MAX_HISTORY: 100, ollama.CONF_MAX_HISTORY: 100.0,
ollama.CONF_NUM_CTX: 32768, ollama.CONF_NUM_CTX: 32768.0,
ollama.CONF_THINK: True, ollama.CONF_THINK: True,
} }
async def test_creating_new_conversation_subentry(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test creating a new conversation subentry includes name field."""
# Start a new subentry flow
with patch(
"ollama.AsyncClient.list",
return_value={"models": [{"model": TEST_MODEL}]},
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.FORM
assert new_flow["step_id"] == "set_options"
# Configure the new subentry with name field
result = await hass.config_entries.subentries.async_configure(
new_flow["flow_id"],
{
ollama.CONF_MODEL: TEST_MODEL,
CONF_NAME: "New Test Conversation",
ollama.CONF_PROMPT: "new test prompt",
ollama.CONF_MAX_HISTORY: 50,
ollama.CONF_NUM_CTX: 16384,
ollama.CONF_THINK: False,
},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == "New Test Conversation"
assert result["data"] == {
ollama.CONF_MODEL: TEST_MODEL,
ollama.CONF_PROMPT: "new test prompt",
ollama.CONF_MAX_HISTORY: 50.0,
ollama.CONF_NUM_CTX: 16384.0,
ollama.CONF_THINK: False,
}
async def test_creating_conversation_subentry_not_loaded( async def test_creating_conversation_subentry_not_loaded(
hass: HomeAssistant, hass: HomeAssistant,
mock_init_component, mock_init_component,
@ -237,6 +188,125 @@ async def test_creating_conversation_subentry_not_loaded(
assert result["reason"] == "entry_not_loaded" assert result["reason"] == "entry_not_loaded"
async def test_subentry_need_download(
hass: HomeAssistant,
mock_init_component,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test subentry creation when model needs to be downloaded."""
async def delayed_pull(self, model: str) -> None:
"""Simulate a delayed model download."""
assert model == "llama3.2:latest"
await asyncio.sleep(0) # yield the event loop 1 iteration
with (
patch(
"ollama.AsyncClient.list",
return_value={"models": [{"model": TEST_MODEL}]},
),
patch("ollama.AsyncClient.pull", delayed_pull),
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.FORM, new_flow
assert new_flow["step_id"] == "set_options"
# Configure the new subentry with a model that needs downloading
result = await hass.config_entries.subentries.async_configure(
new_flow["flow_id"],
{
ollama.CONF_MODEL: "llama3.2:latest", # not cached
CONF_NAME: "New Test Conversation",
ollama.CONF_PROMPT: "new test prompt",
ollama.CONF_MAX_HISTORY: 50,
ollama.CONF_NUM_CTX: 16384,
ollama.CONF_THINK: False,
},
)
assert result["type"] is FlowResultType.SHOW_PROGRESS
assert result["step_id"] == "download"
assert result["progress_action"] == "download"
await hass.async_block_till_done()
result = await hass.config_entries.subentries.async_configure(
new_flow["flow_id"], {}
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == "New Test Conversation"
assert result["data"] == {
ollama.CONF_MODEL: "llama3.2:latest",
ollama.CONF_PROMPT: "new test prompt",
ollama.CONF_MAX_HISTORY: 50.0,
ollama.CONF_NUM_CTX: 16384.0,
ollama.CONF_THINK: False,
}
async def test_subentry_download_error(
hass: HomeAssistant,
mock_init_component,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test subentry creation when model download fails."""
async def delayed_pull(self, model: str) -> None:
"""Simulate a delayed model download."""
await asyncio.sleep(0) # yield
raise RuntimeError("Download failed")
with (
patch(
"ollama.AsyncClient.list",
return_value={"models": [{"model": TEST_MODEL}]},
),
patch("ollama.AsyncClient.pull", delayed_pull),
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.FORM
assert new_flow["step_id"] == "set_options"
# Configure with a model that needs downloading but will fail
result = await hass.config_entries.subentries.async_configure(
new_flow["flow_id"],
{
ollama.CONF_MODEL: "llama3.2:latest",
CONF_NAME: "New Test Conversation",
ollama.CONF_PROMPT: "new test prompt",
ollama.CONF_MAX_HISTORY: 50,
ollama.CONF_NUM_CTX: 16384,
ollama.CONF_THINK: False,
},
)
# Should show progress flow result for download
assert result["type"] is FlowResultType.SHOW_PROGRESS
assert result["step_id"] == "download"
assert result["progress_action"] == "download"
# Wait for download task to complete (with error)
await hass.async_block_till_done()
# Submit the progress flow - should get failure
result = await hass.config_entries.subentries.async_configure(
new_flow["flow_id"], {}
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "download_failed"
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "error"), ("side_effect", "error"),
[ [
@ -262,40 +332,132 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
assert result2["errors"] == {"base": error} assert result2["errors"] == {"base": error}
async def test_download_error(hass: HomeAssistant) -> None: async def test_form_invalid_url(hass: HomeAssistant) -> None:
"""Test we handle errors while downloading a model.""" """Test we handle invalid URL."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER} ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
) )
async def _delayed_runtime_error(*args, **kwargs): result2 = await hass.config_entries.flow.async_configure(
await asyncio.sleep(0) result["flow_id"], {ollama.CONF_URL: "not-a-valid-url"}
raise RuntimeError )
assert result2["type"] is FlowResultType.FORM
assert result2["errors"] == {"base": "invalid_url"}
async def test_subentry_connection_error(
hass: HomeAssistant,
mock_init_component,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test subentry creation when connection to Ollama server fails."""
with patch(
"ollama.AsyncClient.list",
side_effect=ConnectError("Connection failed"),
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.ABORT
assert new_flow["reason"] == "cannot_connect"
async def test_subentry_model_check_exception(
hass: HomeAssistant,
mock_init_component,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test subentry creation when checking model availability throws exception."""
with patch(
"ollama.AsyncClient.list",
side_effect=[
{"models": [{"model": TEST_MODEL}]}, # First call succeeds
RuntimeError("Failed to check models"), # Second call fails
],
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.FORM
assert new_flow["step_id"] == "set_options"
# Configure with a model, should fail when checking availability
result = await hass.config_entries.subentries.async_configure(
new_flow["flow_id"],
{
ollama.CONF_MODEL: "new_model:latest",
CONF_NAME: "Test Conversation",
ollama.CONF_PROMPT: "test prompt",
ollama.CONF_MAX_HISTORY: 50,
ollama.CONF_NUM_CTX: 16384,
ollama.CONF_THINK: False,
},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "cannot_connect"
async def test_subentry_reconfigure_with_download(
hass: HomeAssistant,
mock_init_component,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test reconfiguring subentry when model needs to be downloaded."""
subentry = next(iter(mock_config_entry.subentries.values()))
async def delayed_pull(self, model: str) -> None:
"""Simulate a delayed model download."""
assert model == "llama3.2:latest"
await asyncio.sleep(0) # yield the event loop
with ( with (
patch( patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", "ollama.AsyncClient.list",
return_value={}, return_value={"models": [{"model": TEST_MODEL}]},
),
patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull",
_delayed_runtime_error,
), ),
patch("ollama.AsyncClient.pull", delayed_pull),
): ):
result2 = await hass.config_entries.flow.async_configure( reconfigure_flow = await mock_config_entry.start_subentry_reconfigure_flow(
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"} hass, subentry.subentry_id
) )
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.FORM assert reconfigure_flow["type"] is FlowResultType.FORM
result3 = await hass.config_entries.flow.async_configure( assert reconfigure_flow["step_id"] == "set_options"
result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
# Reconfigure with a model that needs downloading
result = await hass.config_entries.subentries.async_configure(
reconfigure_flow["flow_id"],
{
ollama.CONF_MODEL: "llama3.2:latest",
ollama.CONF_PROMPT: "updated prompt",
ollama.CONF_MAX_HISTORY: 75,
ollama.CONF_NUM_CTX: 8192,
ollama.CONF_THINK: True,
},
) )
assert result["type"] is FlowResultType.SHOW_PROGRESS
assert result["step_id"] == "download"
await hass.async_block_till_done() await hass.async_block_till_done()
assert result3["type"] is FlowResultType.SHOW_PROGRESS # Finish download
result4 = await hass.config_entries.flow.async_configure(result3["flow_id"]) result = await hass.config_entries.subentries.async_configure(
await hass.async_block_till_done() reconfigure_flow["flow_id"], {}
)
assert result4["type"] is FlowResultType.ABORT assert result["type"] is FlowResultType.ABORT
assert result4["reason"] == "download_failed" assert result["reason"] == "reconfigure_successful"
assert subentry.data == {
ollama.CONF_MODEL: "llama3.2:latest",
ollama.CONF_PROMPT: "updated prompt",
ollama.CONF_MAX_HISTORY: 75.0,
ollama.CONF_NUM_CTX: 8192.0,
ollama.CONF_THINK: True,
}

View File

@ -15,7 +15,12 @@ from homeassistant.components.conversation import trace
from homeassistant.const import ATTR_SUPPORTED_FEATURES, CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import ATTR_SUPPORTED_FEATURES, CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm from homeassistant.helpers import (
device_registry as dr,
entity_registry as er,
intent,
llm,
)
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -68,7 +73,7 @@ async def test_chat(
args = mock_chat.call_args.kwargs args = mock_chat.call_args.kwargs
prompt = args["messages"][0]["content"] prompt = args["messages"][0]["content"]
assert args["model"] == "test model" assert args["model"] == "test_model:latest"
assert args["messages"] == [ assert args["messages"] == [
Message(role="system", content=prompt), Message(role="system", content=prompt),
Message(role="user", content="test message"), Message(role="user", content="test message"),
@ -128,7 +133,7 @@ async def test_chat_stream(
args = mock_chat.call_args.kwargs args = mock_chat.call_args.kwargs
prompt = args["messages"][0]["content"] prompt = args["messages"][0]["content"]
assert args["model"] == "test model" assert args["model"] == "test_model:latest"
assert args["messages"] == [ assert args["messages"] == [
Message(role="system", content=prompt), Message(role="system", content=prompt),
Message(role="user", content="test message"), Message(role="user", content="test message"),
@ -158,6 +163,7 @@ async def test_template_variables(
"The user name is {{ user_name }}. " "The user name is {{ user_name }}. "
"The user id is {{ llm_context.context.user_id }}." "The user id is {{ llm_context.context.user_id }}."
), ),
ollama.CONF_MODEL: "test_model:latest",
}, },
) )
with ( with (
@ -524,7 +530,9 @@ async def test_message_history_unlimited(
): ):
subentry = next(iter(mock_config_entry.subentries.values())) subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry( hass.config_entries.async_update_subentry(
mock_config_entry, subentry, data={ollama.CONF_MAX_HISTORY: 0} mock_config_entry,
subentry,
data={**subentry.data, ollama.CONF_MAX_HISTORY: 0},
) )
for i in range(100): for i in range(100):
result = await conversation.async_converse( result = await conversation.async_converse(
@ -573,6 +581,7 @@ async def test_template_error(
mock_config_entry, mock_config_entry,
subentry, subentry,
data={ data={
**subentry.data,
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
}, },
) )
@ -593,6 +602,8 @@ async def test_conversation_agent(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
mock_init_component, mock_init_component,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None: ) -> None:
"""Test OllamaConversationEntity.""" """Test OllamaConversationEntity."""
agent = conversation.get_agent_manager(hass).async_get_agent( agent = conversation.get_agent_manager(hass).async_get_agent(
@ -604,6 +615,24 @@ async def test_conversation_agent(
assert state assert state
assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0 assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0
entity_entry = entity_registry.async_get("conversation.ollama_conversation")
assert entity_entry
subentry = mock_config_entry.subentries.get(entity_entry.unique_id)
assert subentry
assert entity_entry.original_name == subentry.title
device_entry = device_registry.async_get(entity_entry.device_id)
assert device_entry
assert device_entry.identifiers == {(ollama.DOMAIN, subentry.subentry_id)}
assert device_entry.name == subentry.title
assert device_entry.manufacturer == "Ollama"
assert device_entry.entry_type == dr.DeviceEntryType.SERVICE
model, _, version = subentry.data[ollama.CONF_MODEL].partition(":")
assert device_entry.model == model
assert device_entry.sw_version == version
async def test_conversation_agent_with_assist( async def test_conversation_agent_with_assist(
hass: HomeAssistant, hass: HomeAssistant,
@ -679,6 +708,7 @@ async def test_reasoning_filter(
mock_config_entry, mock_config_entry,
subentry, subentry,
data={ data={
**subentry.data,
ollama.CONF_THINK: think, ollama.CONF_THINK: think,
}, },
) )

View File

@ -9,13 +9,26 @@ from homeassistant.components import ollama
from homeassistant.components.ollama.const import DOMAIN from homeassistant.components.ollama.const import DOMAIN
from homeassistant.config_entries import ConfigSubentryData from homeassistant.config_entries import ConfigSubentryData
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers import device_registry as dr, entity_registry as er, llm
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import TEST_OPTIONS, TEST_USER_DATA from . import TEST_OPTIONS
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
V1_TEST_USER_DATA = {
ollama.CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: "test_model:latest",
}
V1_TEST_OPTIONS = {
ollama.CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
ollama.CONF_MAX_HISTORY: 2,
}
V21_TEST_USER_DATA = V1_TEST_USER_DATA
V21_TEST_OPTIONS = V1_TEST_OPTIONS
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "error"), ("side_effect", "error"),
@ -41,17 +54,17 @@ async def test_init_error(
assert error in caplog.text assert error in caplog.text
async def test_migration_from_v1_to_v2( async def test_migration_from_v1(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
) -> None: ) -> None:
"""Test migration from version 1 to version 2.""" """Test migration from version 1."""
# Create a v1 config entry with conversation options and an entity # Create a v1 config entry with conversation options and an entity
mock_config_entry = MockConfigEntry( mock_config_entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
data=TEST_USER_DATA, data=V1_TEST_USER_DATA,
options=TEST_OPTIONS, options=V1_TEST_OPTIONS,
version=1, version=1,
title="llama-3.2-8b", title="llama-3.2-8b",
) )
@ -81,9 +94,10 @@ async def test_migration_from_v1_to_v2(
): ):
await hass.config_entries.async_setup(mock_config_entry.entry_id) await hass.config_entries.async_setup(mock_config_entry.entry_id)
assert mock_config_entry.version == 2 assert mock_config_entry.version == 3
assert mock_config_entry.minor_version == 2 assert mock_config_entry.minor_version == 1
assert mock_config_entry.data == TEST_USER_DATA # After migration, parent entry should only have URL
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
assert mock_config_entry.options == {} assert mock_config_entry.options == {}
assert len(mock_config_entry.subentries) == 1 assert len(mock_config_entry.subentries) == 1
@ -92,7 +106,9 @@ async def test_migration_from_v1_to_v2(
assert subentry.unique_id is None assert subentry.unique_id is None
assert subentry.title == "llama-3.2-8b" assert subentry.title == "llama-3.2-8b"
assert subentry.subentry_type == "conversation" assert subentry.subentry_type == "conversation"
assert subentry.data == TEST_OPTIONS # Subentry should now include the model from the original options
expected_subentry_data = TEST_OPTIONS.copy()
assert subentry.data == expected_subentry_data
migrated_entity = entity_registry.async_get(entity.entity_id) migrated_entity = entity_registry.async_get(entity.entity_id)
assert migrated_entity is not None assert migrated_entity is not None
@ -117,17 +133,17 @@ async def test_migration_from_v1_to_v2(
} }
async def test_migration_from_v1_to_v2_with_multiple_urls( async def test_migration_from_v1_with_multiple_urls(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
) -> None: ) -> None:
"""Test migration from version 1 to version 2 with different URLs.""" """Test migration from version 1 with different URLs."""
# Create two v1 config entries with different URLs # Create two v1 config entries with different URLs
mock_config_entry = MockConfigEntry( mock_config_entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
options=TEST_OPTIONS, options=V1_TEST_OPTIONS,
version=1, version=1,
title="Ollama 1", title="Ollama 1",
) )
@ -135,7 +151,7 @@ async def test_migration_from_v1_to_v2_with_multiple_urls(
mock_config_entry_2 = MockConfigEntry( mock_config_entry_2 = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
data={"url": "http://localhost:11435", "model": "llama3.2:latest"}, data={"url": "http://localhost:11435", "model": "llama3.2:latest"},
options=TEST_OPTIONS, options=V1_TEST_OPTIONS,
version=1, version=1,
title="Ollama 2", title="Ollama 2",
) )
@ -187,13 +203,16 @@ async def test_migration_from_v1_to_v2_with_multiple_urls(
assert len(entries) == 2 assert len(entries) == 2
for idx, entry in enumerate(entries): for idx, entry in enumerate(entries):
assert entry.version == 2 assert entry.version == 3
assert entry.minor_version == 2 assert entry.minor_version == 1
assert not entry.options assert not entry.options
assert len(entry.subentries) == 1 assert len(entry.subentries) == 1
subentry = list(entry.subentries.values())[0] subentry = list(entry.subentries.values())[0]
assert subentry.subentry_type == "conversation" assert subentry.subentry_type == "conversation"
assert subentry.data == TEST_OPTIONS # Subentry should include the model along with the original options
expected_subentry_data = TEST_OPTIONS.copy()
expected_subentry_data["model"] = "llama3.2:latest"
assert subentry.data == expected_subentry_data
assert subentry.title == f"Ollama {idx + 1}" assert subentry.title == f"Ollama {idx + 1}"
dev = device_registry.async_get_device( dev = device_registry.async_get_device(
@ -204,17 +223,17 @@ async def test_migration_from_v1_to_v2_with_multiple_urls(
assert dev.config_entries_subentries == {entry.entry_id: {subentry.subentry_id}} assert dev.config_entries_subentries == {entry.entry_id: {subentry.subentry_id}}
async def test_migration_from_v1_to_v2_with_same_urls( async def test_migration_from_v1_with_same_urls(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
) -> None: ) -> None:
"""Test migration from version 1 to version 2 with same URLs consolidates entries.""" """Test migration from version 1 with same URLs consolidates entries."""
# Create two v1 config entries with the same URL # Create two v1 config entries with the same URL
mock_config_entry = MockConfigEntry( mock_config_entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
options=TEST_OPTIONS, options=V1_TEST_OPTIONS,
version=1, version=1,
title="Ollama", title="Ollama",
) )
@ -222,7 +241,7 @@ async def test_migration_from_v1_to_v2_with_same_urls(
mock_config_entry_2 = MockConfigEntry( mock_config_entry_2 = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, # Same URL data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, # Same URL
options=TEST_OPTIONS, options=V1_TEST_OPTIONS,
version=1, version=1,
title="Ollama 2", title="Ollama 2",
) )
@ -275,8 +294,8 @@ async def test_migration_from_v1_to_v2_with_same_urls(
assert len(entries) == 1 assert len(entries) == 1
entry = entries[0] entry = entries[0]
assert entry.version == 2 assert entry.version == 3
assert entry.minor_version == 2 assert entry.minor_version == 1
assert not entry.options assert not entry.options
assert len(entry.subentries) == 2 # Two subentries from the two original entries assert len(entry.subentries) == 2 # Two subentries from the two original entries
@ -288,7 +307,10 @@ async def test_migration_from_v1_to_v2_with_same_urls(
for subentry in subentries: for subentry in subentries:
assert subentry.subentry_type == "conversation" assert subentry.subentry_type == "conversation"
assert subentry.data == TEST_OPTIONS # Subentry should include the model along with the original options
expected_subentry_data = TEST_OPTIONS.copy()
expected_subentry_data["model"] = "llama3.2:latest"
assert subentry.data == expected_subentry_data
# Check devices were migrated correctly # Check devices were migrated correctly
dev = device_registry.async_get_device( dev = device_registry.async_get_device(
@ -301,12 +323,12 @@ async def test_migration_from_v1_to_v2_with_same_urls(
} }
async def test_migration_from_v2_1_to_v2_2( async def test_migration_from_v2_1(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
) -> None: ) -> None:
"""Test migration from version 2.1 to version 2.2. """Test migration from version 2.1.
This tests we clean up the broken migration in Home Assistant Core This tests we clean up the broken migration in Home Assistant Core
2025.7.0b0-2025.7.0b1: 2025.7.0b0-2025.7.0b1:
@ -315,20 +337,20 @@ async def test_migration_from_v2_1_to_v2_2(
# Create a v2.1 config entry with 2 subentries, devices and entities # Create a v2.1 config entry with 2 subentries, devices and entities
mock_config_entry = MockConfigEntry( mock_config_entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
data=TEST_USER_DATA, data=V21_TEST_USER_DATA,
entry_id="mock_entry_id", entry_id="mock_entry_id",
version=2, version=2,
minor_version=1, minor_version=1,
subentries_data=[ subentries_data=[
ConfigSubentryData( ConfigSubentryData(
data=TEST_OPTIONS, data=V21_TEST_OPTIONS,
subentry_id="mock_id_1", subentry_id="mock_id_1",
subentry_type="conversation", subentry_type="conversation",
title="Ollama", title="Ollama",
unique_id=None, unique_id=None,
), ),
ConfigSubentryData( ConfigSubentryData(
data=TEST_OPTIONS, data=V21_TEST_OPTIONS,
subentry_id="mock_id_2", subentry_id="mock_id_2",
subentry_type="conversation", subentry_type="conversation",
title="Ollama 2", title="Ollama 2",
@ -392,8 +414,8 @@ async def test_migration_from_v2_1_to_v2_2(
entries = hass.config_entries.async_entries(DOMAIN) entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1 assert len(entries) == 1
entry = entries[0] entry = entries[0]
assert entry.version == 2 assert entry.version == 3
assert entry.minor_version == 2 assert entry.minor_version == 1
assert not entry.options assert not entry.options
assert entry.title == "Ollama" assert entry.title == "Ollama"
assert len(entry.subentries) == 2 assert len(entry.subentries) == 2
@ -405,6 +427,7 @@ async def test_migration_from_v2_1_to_v2_2(
assert len(conversation_subentries) == 2 assert len(conversation_subentries) == 2
for subentry in conversation_subentries: for subentry in conversation_subentries:
assert subentry.subentry_type == "conversation" assert subentry.subentry_type == "conversation"
# Since TEST_USER_DATA no longer has a model, subentry data should be TEST_OPTIONS
assert subentry.data == TEST_OPTIONS assert subentry.data == TEST_OPTIONS
assert "Ollama" in subentry.title assert "Ollama" in subentry.title
@ -450,3 +473,45 @@ async def test_migration_from_v2_1_to_v2_2(
assert device.config_entries_subentries == { assert device.config_entries_subentries == {
mock_config_entry.entry_id: {subentry.subentry_id} mock_config_entry.entry_id: {subentry.subentry_id}
} }
async def test_migration_from_v2_2(hass: HomeAssistant) -> None:
"""Test migration from version 2.2."""
subentry_data = ConfigSubentryData(
data=V21_TEST_USER_DATA,
subentry_type="conversation",
title="Test Conversation",
unique_id=None,
)
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={
ollama.CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: "test_model:latest", # Model still in main data
},
version=2,
minor_version=2,
subentries_data=[subentry_data],
)
mock_config_entry.add_to_hass(hass)
with patch(
"homeassistant.components.ollama.async_setup_entry",
return_value=True,
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
# Check migration to v3.1
assert mock_config_entry.version == 3
assert mock_config_entry.minor_version == 1
# Check that model was moved from main data to subentry
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
assert len(mock_config_entry.subentries) == 1
subentry = next(iter(mock_config_entry.subentries.values()))
assert subentry.data == {
**V21_TEST_USER_DATA,
ollama.CONF_MODEL: "test_model:latest",
}