Ollama: Migrate pick model to subentry (#147944)

This commit is contained in:
Paulus Schoutsen 2025-07-02 15:20:42 +02:00 committed by GitHub
parent 943fb9948b
commit f50ef79c72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 655 additions and 333 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import logging
from types import MappingProxyType
import httpx
import ollama
@ -100,8 +101,12 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
for entry in entries:
use_existing = False
# Create subentry with model from entry.data and options from entry.options
subentry_data = entry.options.copy()
subentry_data[CONF_MODEL] = entry.data[CONF_MODEL]
subentry = ConfigSubentry(
data=entry.options,
data=MappingProxyType(subentry_data),
subentry_type="conversation",
title=entry.title,
unique_id=None,
@ -154,9 +159,11 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
hass.config_entries.async_update_entry(
entry,
title=DEFAULT_NAME,
# Update parent entry to only keep URL, remove model
data={CONF_URL: entry.data[CONF_URL]},
options={},
version=2,
minor_version=2,
version=3,
minor_version=1,
)
@ -164,7 +171,7 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OllamaConfigEntry) ->
"""Migrate entry."""
_LOGGER.debug("Migrating from version %s:%s", entry.version, entry.minor_version)
if entry.version > 2:
if entry.version > 3:
# This means the user has downgraded from a future version
return False
@ -182,6 +189,25 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OllamaConfigEntry) ->
hass.config_entries.async_update_entry(entry, minor_version=2)
if entry.version == 2 and entry.minor_version == 2:
# Update subentries to include the model
for subentry in entry.subentries.values():
if subentry.subentry_type == "conversation":
updated_data = dict(subentry.data)
updated_data[CONF_MODEL] = entry.data[CONF_MODEL]
hass.config_entries.async_update_subentry(
entry, subentry, data=MappingProxyType(updated_data)
)
# Update main entry to remove model and bump version
hass.config_entries.async_update_entry(
entry,
data={CONF_URL: entry.data[CONF_URL]},
version=3,
minor_version=1,
)
_LOGGER.debug(
"Migration to version %s:%s successful", entry.version, entry.minor_version
)

View File

@ -22,7 +22,7 @@ from homeassistant.config_entries import (
)
from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import llm
from homeassistant.helpers import config_validation as cv, llm
from homeassistant.helpers.selector import (
BooleanSelector,
NumberSelector,
@ -38,6 +38,7 @@ from homeassistant.helpers.selector import (
)
from homeassistant.util.ssl import get_default_context
from . import OllamaConfigEntry
from .const import (
CONF_KEEP_ALIVE,
CONF_MAX_HISTORY,
@ -72,43 +73,43 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Ollama."""
VERSION = 2
MINOR_VERSION = 2
VERSION = 3
MINOR_VERSION = 1
def __init__(self) -> None:
"""Initialize config flow."""
self.url: str | None = None
self.model: str | None = None
self.client: ollama.AsyncClient | None = None
self.download_task: asyncio.Task | None = None
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the initial step."""
user_input = user_input or {}
self.url = user_input.get(CONF_URL, self.url)
self.model = user_input.get(CONF_MODEL, self.model)
if self.url is None:
if user_input is None:
return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, last_step=False
step_id="user", data_schema=STEP_USER_DATA_SCHEMA
)
errors = {}
url = user_input[CONF_URL]
self._async_abort_entries_match({CONF_URL: self.url})
self._async_abort_entries_match({CONF_URL: url})
try:
self.client = ollama.AsyncClient(
host=self.url, verify=get_default_context()
url = cv.url(url)
except vol.Invalid:
errors["base"] = "invalid_url"
return self.async_show_form(
step_id="user",
data_schema=self.add_suggested_values_to_schema(
STEP_USER_DATA_SCHEMA, user_input
),
errors=errors,
)
async with asyncio.timeout(DEFAULT_TIMEOUT):
response = await self.client.list()
downloaded_models: set[str] = {
model_info["model"] for model_info in response.get("models", [])
}
try:
client = ollama.AsyncClient(host=url, verify=get_default_context())
async with asyncio.timeout(DEFAULT_TIMEOUT):
await client.list()
except (TimeoutError, httpx.ConnectError):
errors["base"] = "cannot_connect"
except Exception:
@ -117,10 +118,69 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
if errors:
return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
step_id="user",
data_schema=self.add_suggested_values_to_schema(
STEP_USER_DATA_SCHEMA, user_input
),
errors=errors,
)
if self.model is None:
return self.async_create_entry(
title=url,
data={CONF_URL: url},
)
@classmethod
@callback
def async_get_supported_subentry_types(
cls, config_entry: ConfigEntry
) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this integration."""
return {"conversation": ConversationSubentryFlowHandler}
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
"""Flow for managing conversation subentries."""
def __init__(self) -> None:
"""Initialize the subentry flow."""
super().__init__()
self._name: str | None = None
self._model: str | None = None
self.download_task: asyncio.Task | None = None
self._config_data: dict[str, Any] | None = None
@property
def _is_new(self) -> bool:
"""Return if this is a new subentry."""
return self.source == "user"
@property
def _client(self) -> ollama.AsyncClient:
"""Return the Ollama client."""
entry: OllamaConfigEntry = self._get_entry()
return entry.runtime_data
async def async_step_set_options(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Handle model selection and configuration step."""
if self._get_entry().state != ConfigEntryState.LOADED:
return self.async_abort(reason="entry_not_loaded")
if user_input is None:
# Get available models from Ollama server
try:
async with asyncio.timeout(DEFAULT_TIMEOUT):
response = await self._client.list()
downloaded_models: set[str] = {
model_info["model"] for model_info in response.get("models", [])
}
except (TimeoutError, httpx.ConnectError, httpx.HTTPError):
_LOGGER.exception("Failed to get models from Ollama server")
return self.async_abort(reason="cannot_connect")
# Show models that have been downloaded first, followed by all known
# models (only latest tags).
models_to_list = [
@ -131,52 +191,69 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
for m in sorted(MODEL_NAMES)
if m not in downloaded_models
]
model_step_schema = vol.Schema(
{
vol.Required(
CONF_MODEL, description={"suggested_value": DEFAULT_MODEL}
): SelectSelector(
SelectSelectorConfig(options=models_to_list, custom_value=True)
),
}
)
if self._is_new:
options = {}
else:
options = self._get_reconfigure_subentry().data.copy()
return self.async_show_form(
step_id="user",
data_schema=model_step_schema,
step_id="set_options",
data_schema=vol.Schema(
ollama_config_option_schema(
self.hass, self._is_new, options, models_to_list
)
),
)
if self.model not in downloaded_models:
# Ollama server needs to download model first
return await self.async_step_download()
self._model = user_input[CONF_MODEL]
if self._is_new:
self._name = user_input.pop(CONF_NAME)
return self.async_create_entry(
title=self.url,
data={CONF_URL: self.url, CONF_MODEL: self.model},
subentries=[
{
"subentry_type": "conversation",
"data": {},
"title": _get_title(self.model),
"unique_id": None,
}
],
# Check if model needs to be downloaded
try:
async with asyncio.timeout(DEFAULT_TIMEOUT):
response = await self._client.list()
currently_downloaded_models: set[str] = {
model_info["model"] for model_info in response.get("models", [])
}
if self._model not in currently_downloaded_models:
# Store the user input to use after download
self._config_data = user_input
# Ollama server needs to download model first
return await self.async_step_download()
except Exception:
_LOGGER.exception("Failed to check model availability")
return self.async_abort(reason="cannot_connect")
# Model is already downloaded, create/update the entry
if self._is_new:
return self.async_create_entry(
title=self._name,
data=user_input,
)
return self.async_update_and_abort(
self._get_entry(),
self._get_reconfigure_subentry(),
data=user_input,
)
async def async_step_download(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
) -> SubentryFlowResult:
"""Step to wait for Ollama server to download a model."""
assert self.model is not None
assert self.client is not None
assert self._model is not None
if self.download_task is None:
# Tell Ollama server to pull the model.
# The task will block until the model and metadata are fully
# downloaded.
self.download_task = self.hass.async_create_background_task(
self.client.pull(self.model),
f"Downloading {self.model}",
self._client.pull(self._model),
f"Downloading {self._model}",
)
if self.download_task.done():
@ -192,80 +269,28 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
progress_task=self.download_task,
)
async def async_step_finish(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Step after model downloading has succeeded."""
assert self.url is not None
assert self.model is not None
return self.async_create_entry(
title=_get_title(self.model),
data={CONF_URL: self.url, CONF_MODEL: self.model},
subentries=[
{
"subentry_type": "conversation",
"data": {},
"title": _get_title(self.model),
"unique_id": None,
}
],
)
async def async_step_failed(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
) -> SubentryFlowResult:
"""Step after model downloading has failed."""
return self.async_abort(reason="download_failed")
@classmethod
@callback
def async_get_supported_subentry_types(
cls, config_entry: ConfigEntry
) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this integration."""
return {"conversation": ConversationSubentryFlowHandler}
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
"""Flow for managing conversation subentries."""
@property
def _is_new(self) -> bool:
"""Return if this is a new subentry."""
return self.source == "user"
async def async_step_set_options(
async def async_step_finish(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Set conversation options."""
# abort if entry is not loaded
if self._get_entry().state != ConfigEntryState.LOADED:
return self.async_abort(reason="entry_not_loaded")
"""Step after model downloading has succeeded."""
assert self._config_data is not None
errors: dict[str, str] = {}
if user_input is None:
if self._is_new:
options = {}
else:
options = self._get_reconfigure_subentry().data.copy()
elif self._is_new:
# Model download completed, create/update the entry with stored config
if self._is_new:
return self.async_create_entry(
title=user_input.pop(CONF_NAME),
data=user_input,
title=self._name,
data=self._config_data,
)
else:
return self.async_update_and_abort(
self._get_entry(),
self._get_reconfigure_subentry(),
data=user_input,
)
schema = ollama_config_option_schema(self.hass, self._is_new, options)
return self.async_show_form(
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
return self.async_update_and_abort(
self._get_entry(),
self._get_reconfigure_subentry(),
data=self._config_data,
)
async_step_user = async_step_set_options
@ -273,19 +298,14 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
def ollama_config_option_schema(
hass: HomeAssistant, is_new: bool, options: Mapping[str, Any]
hass: HomeAssistant,
is_new: bool,
options: Mapping[str, Any],
models_to_list: list[SelectOptionDict],
) -> dict:
"""Ollama options schema."""
hass_apis: list[SelectOptionDict] = [
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
]
if is_new:
schema: dict[vol.Required | vol.Optional, Any] = {
schema: dict = {
vol.Required(CONF_NAME, default="Ollama Conversation"): str,
}
else:
@ -293,6 +313,12 @@ def ollama_config_option_schema(
schema.update(
{
vol.Required(
CONF_MODEL,
description={"suggested_value": options.get(CONF_MODEL, DEFAULT_MODEL)},
): SelectSelector(
SelectSelectorConfig(options=models_to_list, custom_value=True)
),
vol.Optional(
CONF_PROMPT,
description={
@ -304,7 +330,18 @@ def ollama_config_option_schema(
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
): SelectSelector(
SelectSelectorConfig(
options=[
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
],
multiple=True,
)
),
vol.Optional(
CONF_NUM_CTX,
description={
@ -350,11 +387,3 @@ def ollama_config_option_schema(
)
return schema
def _get_title(model: str) -> str:
"""Get title for config entry."""
if model.endswith(":latest"):
model = model.split(":", maxsplit=1)[0]
return model

View File

@ -166,11 +166,14 @@ class OllamaBaseLLMEntity(Entity):
self.subentry = subentry
self._attr_name = subentry.title
self._attr_unique_id = subentry.subentry_id
model, _, version = subentry.data[CONF_MODEL].partition(":")
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, subentry.subentry_id)},
name=subentry.title,
manufacturer="Ollama",
model=entry.data[CONF_MODEL],
model=model,
sw_version=version or "latest",
entry_type=dr.DeviceEntryType.SERVICE,
)

View File

@ -3,24 +3,17 @@
"step": {
"user": {
"data": {
"url": "[%key:common::config_flow::data::url%]",
"model": "Model"
"url": "[%key:common::config_flow::data::url%]"
}
},
"download": {
"title": "Downloading model"
}
},
"abort": {
"download_failed": "Model downloading failed",
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]"
},
"error": {
"invalid_url": "[%key:common::config_flow::error::invalid_host%]",
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
"unknown": "[%key:common::config_flow::error::unknown%]"
},
"progress": {
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
}
},
"config_subentries": {
@ -33,6 +26,7 @@
"step": {
"set_options": {
"data": {
"model": "Model",
"name": "[%key:common::config_flow::data::name%]",
"prompt": "Instructions",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
@ -47,11 +41,19 @@
"num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.",
"think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency."
}
},
"download": {
"title": "Downloading model"
}
},
"abort": {
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
"entry_not_loaded": "Cannot add things while the configuration is disabled."
"entry_not_loaded": "Failed to add agent. The configuration is disabled.",
"download_failed": "Model downloading failed",
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
},
"progress": {
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
}
}
}

View File

@ -5,10 +5,10 @@ from homeassistant.helpers import llm
TEST_USER_DATA = {
ollama.CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: "test model",
}
TEST_OPTIONS = {
ollama.CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
ollama.CONF_MAX_HISTORY: 2,
ollama.CONF_MODEL: "test_model:latest",
}

View File

@ -30,10 +30,11 @@ def mock_config_entry(
entry = MockConfigEntry(
domain=ollama.DOMAIN,
data=TEST_USER_DATA,
version=2,
version=3,
minor_version=1,
subentries_data=[
{
"data": mock_config_entry_options,
"data": {**TEST_OPTIONS, **mock_config_entry_options},
"subentry_type": "conversation",
"title": "Ollama Conversation",
"unique_id": None,
@ -49,10 +50,14 @@ def mock_config_entry_with_assist(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> MockConfigEntry:
"""Mock a config entry with assist."""
subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry(
mock_config_entry,
next(iter(mock_config_entry.subentries.values())),
data={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
subentry,
data={
**subentry.data,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
},
)
return mock_config_entry

View File

@ -8,6 +8,7 @@ import pytest
from homeassistant import config_entries
from homeassistant.components import ollama
from homeassistant.const import CONF_NAME
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
@ -17,7 +18,7 @@ TEST_MODEL = "test_model:latest"
async def test_form(hass: HomeAssistant) -> None:
"""Test flow when the model is already downloaded."""
"""Test flow when configuring URL only."""
# Pretend we already set up a config entry.
hass.config.components.add(ollama.DOMAIN)
MockConfigEntry(
@ -34,7 +35,6 @@ async def test_form(hass: HomeAssistant) -> None:
with (
patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
# test model is already "downloaded"
return_value={"models": [{"model": TEST_MODEL}]},
),
patch(
@ -42,24 +42,17 @@ async def test_form(hass: HomeAssistant) -> None:
return_value=True,
) as mock_setup_entry,
):
# Step 1: URL
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
)
await hass.async_block_till_done()
# Step 2: model
assert result2["type"] is FlowResultType.FORM
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
)
await hass.async_block_till_done()
assert result3["type"] is FlowResultType.CREATE_ENTRY
assert result3["data"] == {
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["data"] == {
ollama.CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: TEST_MODEL,
}
# No subentries created by default
assert len(result2.get("subentries", [])) == 0
assert len(mock_setup_entry.mock_calls) == 1
@ -94,98 +87,6 @@ async def test_duplicate_entry(hass: HomeAssistant) -> None:
assert result["reason"] == "already_configured"
async def test_form_need_download(hass: HomeAssistant) -> None:
"""Test flow when a model needs to be downloaded."""
# Pretend we already set up a config entry.
hass.config.components.add(ollama.DOMAIN)
MockConfigEntry(
domain=ollama.DOMAIN,
state=config_entries.ConfigEntryState.LOADED,
).add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] is None
pull_ready = asyncio.Event()
pull_called = asyncio.Event()
pull_model: str | None = None
async def pull(self, model: str, *args, **kwargs) -> None:
nonlocal pull_model
async with asyncio.timeout(1):
await pull_ready.wait()
pull_model = model
pull_called.set()
with (
patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
# No models are downloaded
return_value={},
),
patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull",
pull,
),
patch(
"homeassistant.components.ollama.async_setup_entry",
return_value=True,
) as mock_setup_entry,
):
# Step 1: URL
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
)
await hass.async_block_till_done()
# Step 2: model
assert result2["type"] is FlowResultType.FORM
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
)
await hass.async_block_till_done()
# Step 3: download
assert result3["type"] is FlowResultType.SHOW_PROGRESS
result4 = await hass.config_entries.flow.async_configure(
result3["flow_id"],
)
await hass.async_block_till_done()
# Run again without the task finishing.
# We should still be downloading.
assert result4["type"] is FlowResultType.SHOW_PROGRESS
result4 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
)
await hass.async_block_till_done()
assert result4["type"] is FlowResultType.SHOW_PROGRESS
# Signal fake pull method to complete
pull_ready.set()
async with asyncio.timeout(1):
await pull_called.wait()
assert pull_model == TEST_MODEL
# Step 4: finish
result5 = await hass.config_entries.flow.async_configure(
result4["flow_id"],
)
assert result5["type"] is FlowResultType.CREATE_ENTRY
assert result5["data"] == {
ollama.CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: TEST_MODEL,
}
assert len(mock_setup_entry.mock_calls) == 1
async def test_subentry_options(
hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
@ -193,34 +94,84 @@ async def test_subentry_options(
subentry = next(iter(mock_config_entry.subentries.values()))
# Test reconfiguration
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
hass, subentry.subentry_id
)
with patch(
"ollama.AsyncClient.list",
return_value={"models": [{"model": TEST_MODEL}]},
):
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
hass, subentry.subentry_id
)
assert options_flow["type"] is FlowResultType.FORM
assert options_flow["step_id"] == "set_options"
assert options_flow["type"] is FlowResultType.FORM
assert options_flow["step_id"] == "set_options"
options = await hass.config_entries.subentries.async_configure(
options_flow["flow_id"],
{
ollama.CONF_PROMPT: "test prompt",
ollama.CONF_MAX_HISTORY: 100,
ollama.CONF_NUM_CTX: 32768,
ollama.CONF_THINK: True,
},
)
options = await hass.config_entries.subentries.async_configure(
options_flow["flow_id"],
{
ollama.CONF_MODEL: TEST_MODEL,
ollama.CONF_PROMPT: "test prompt",
ollama.CONF_MAX_HISTORY: 100,
ollama.CONF_NUM_CTX: 32768,
ollama.CONF_THINK: True,
},
)
await hass.async_block_till_done()
assert options["type"] is FlowResultType.ABORT
assert options["reason"] == "reconfigure_successful"
assert subentry.data == {
ollama.CONF_MODEL: TEST_MODEL,
ollama.CONF_PROMPT: "test prompt",
ollama.CONF_MAX_HISTORY: 100,
ollama.CONF_NUM_CTX: 32768,
ollama.CONF_MAX_HISTORY: 100.0,
ollama.CONF_NUM_CTX: 32768.0,
ollama.CONF_THINK: True,
}
async def test_creating_new_conversation_subentry(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test creating a new conversation subentry includes name field."""
# Start a new subentry flow
with patch(
"ollama.AsyncClient.list",
return_value={"models": [{"model": TEST_MODEL}]},
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.FORM
assert new_flow["step_id"] == "set_options"
# Configure the new subentry with name field
result = await hass.config_entries.subentries.async_configure(
new_flow["flow_id"],
{
ollama.CONF_MODEL: TEST_MODEL,
CONF_NAME: "New Test Conversation",
ollama.CONF_PROMPT: "new test prompt",
ollama.CONF_MAX_HISTORY: 50,
ollama.CONF_NUM_CTX: 16384,
ollama.CONF_THINK: False,
},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == "New Test Conversation"
assert result["data"] == {
ollama.CONF_MODEL: TEST_MODEL,
ollama.CONF_PROMPT: "new test prompt",
ollama.CONF_MAX_HISTORY: 50.0,
ollama.CONF_NUM_CTX: 16384.0,
ollama.CONF_THINK: False,
}
async def test_creating_conversation_subentry_not_loaded(
hass: HomeAssistant,
mock_init_component,
@ -237,6 +188,125 @@ async def test_creating_conversation_subentry_not_loaded(
assert result["reason"] == "entry_not_loaded"
async def test_subentry_need_download(
hass: HomeAssistant,
mock_init_component,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test subentry creation when model needs to be downloaded."""
async def delayed_pull(self, model: str) -> None:
"""Simulate a delayed model download."""
assert model == "llama3.2:latest"
await asyncio.sleep(0) # yield the event loop 1 iteration
with (
patch(
"ollama.AsyncClient.list",
return_value={"models": [{"model": TEST_MODEL}]},
),
patch("ollama.AsyncClient.pull", delayed_pull),
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.FORM, new_flow
assert new_flow["step_id"] == "set_options"
# Configure the new subentry with a model that needs downloading
result = await hass.config_entries.subentries.async_configure(
new_flow["flow_id"],
{
ollama.CONF_MODEL: "llama3.2:latest", # not cached
CONF_NAME: "New Test Conversation",
ollama.CONF_PROMPT: "new test prompt",
ollama.CONF_MAX_HISTORY: 50,
ollama.CONF_NUM_CTX: 16384,
ollama.CONF_THINK: False,
},
)
assert result["type"] is FlowResultType.SHOW_PROGRESS
assert result["step_id"] == "download"
assert result["progress_action"] == "download"
await hass.async_block_till_done()
result = await hass.config_entries.subentries.async_configure(
new_flow["flow_id"], {}
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == "New Test Conversation"
assert result["data"] == {
ollama.CONF_MODEL: "llama3.2:latest",
ollama.CONF_PROMPT: "new test prompt",
ollama.CONF_MAX_HISTORY: 50.0,
ollama.CONF_NUM_CTX: 16384.0,
ollama.CONF_THINK: False,
}
async def test_subentry_download_error(
hass: HomeAssistant,
mock_init_component,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test subentry creation when model download fails."""
async def delayed_pull(self, model: str) -> None:
"""Simulate a delayed model download."""
await asyncio.sleep(0) # yield
raise RuntimeError("Download failed")
with (
patch(
"ollama.AsyncClient.list",
return_value={"models": [{"model": TEST_MODEL}]},
),
patch("ollama.AsyncClient.pull", delayed_pull),
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.FORM
assert new_flow["step_id"] == "set_options"
# Configure with a model that needs downloading but will fail
result = await hass.config_entries.subentries.async_configure(
new_flow["flow_id"],
{
ollama.CONF_MODEL: "llama3.2:latest",
CONF_NAME: "New Test Conversation",
ollama.CONF_PROMPT: "new test prompt",
ollama.CONF_MAX_HISTORY: 50,
ollama.CONF_NUM_CTX: 16384,
ollama.CONF_THINK: False,
},
)
# Should show progress flow result for download
assert result["type"] is FlowResultType.SHOW_PROGRESS
assert result["step_id"] == "download"
assert result["progress_action"] == "download"
# Wait for download task to complete (with error)
await hass.async_block_till_done()
# Submit the progress flow - should get failure
result = await hass.config_entries.subentries.async_configure(
new_flow["flow_id"], {}
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "download_failed"
@pytest.mark.parametrize(
("side_effect", "error"),
[
@ -262,40 +332,132 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
assert result2["errors"] == {"base": error}
async def test_download_error(hass: HomeAssistant) -> None:
"""Test we handle errors while downloading a model."""
async def test_form_invalid_url(hass: HomeAssistant) -> None:
"""Test we handle invalid URL."""
result = await hass.config_entries.flow.async_init(
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
)
async def _delayed_runtime_error(*args, **kwargs):
await asyncio.sleep(0)
raise RuntimeError
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {ollama.CONF_URL: "not-a-valid-url"}
)
assert result2["type"] is FlowResultType.FORM
assert result2["errors"] == {"base": "invalid_url"}
async def test_subentry_connection_error(
hass: HomeAssistant,
mock_init_component,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test subentry creation when connection to Ollama server fails."""
with patch(
"ollama.AsyncClient.list",
side_effect=ConnectError("Connection failed"),
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.ABORT
assert new_flow["reason"] == "cannot_connect"
async def test_subentry_model_check_exception(
hass: HomeAssistant,
mock_init_component,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test subentry creation when checking model availability throws exception."""
with patch(
"ollama.AsyncClient.list",
side_effect=[
{"models": [{"model": TEST_MODEL}]}, # First call succeeds
RuntimeError("Failed to check models"), # Second call fails
],
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.FORM
assert new_flow["step_id"] == "set_options"
# Configure with a model, should fail when checking availability
result = await hass.config_entries.subentries.async_configure(
new_flow["flow_id"],
{
ollama.CONF_MODEL: "new_model:latest",
CONF_NAME: "Test Conversation",
ollama.CONF_PROMPT: "test prompt",
ollama.CONF_MAX_HISTORY: 50,
ollama.CONF_NUM_CTX: 16384,
ollama.CONF_THINK: False,
},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "cannot_connect"
async def test_subentry_reconfigure_with_download(
hass: HomeAssistant,
mock_init_component,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test reconfiguring subentry when model needs to be downloaded."""
subentry = next(iter(mock_config_entry.subentries.values()))
async def delayed_pull(self, model: str) -> None:
"""Simulate a delayed model download."""
assert model == "llama3.2:latest"
await asyncio.sleep(0) # yield the event loop
with (
patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
return_value={},
),
patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull",
_delayed_runtime_error,
"ollama.AsyncClient.list",
return_value={"models": [{"model": TEST_MODEL}]},
),
patch("ollama.AsyncClient.pull", delayed_pull),
):
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
reconfigure_flow = await mock_config_entry.start_subentry_reconfigure_flow(
hass, subentry.subentry_id
)
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.FORM
result3 = await hass.config_entries.flow.async_configure(
result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
assert reconfigure_flow["type"] is FlowResultType.FORM
assert reconfigure_flow["step_id"] == "set_options"
# Reconfigure with a model that needs downloading
result = await hass.config_entries.subentries.async_configure(
reconfigure_flow["flow_id"],
{
ollama.CONF_MODEL: "llama3.2:latest",
ollama.CONF_PROMPT: "updated prompt",
ollama.CONF_MAX_HISTORY: 75,
ollama.CONF_NUM_CTX: 8192,
ollama.CONF_THINK: True,
},
)
assert result["type"] is FlowResultType.SHOW_PROGRESS
assert result["step_id"] == "download"
await hass.async_block_till_done()
assert result3["type"] is FlowResultType.SHOW_PROGRESS
result4 = await hass.config_entries.flow.async_configure(result3["flow_id"])
await hass.async_block_till_done()
# Finish download
result = await hass.config_entries.subentries.async_configure(
reconfigure_flow["flow_id"], {}
)
assert result4["type"] is FlowResultType.ABORT
assert result4["reason"] == "download_failed"
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "reconfigure_successful"
assert subentry.data == {
ollama.CONF_MODEL: "llama3.2:latest",
ollama.CONF_PROMPT: "updated prompt",
ollama.CONF_MAX_HISTORY: 75.0,
ollama.CONF_NUM_CTX: 8192.0,
ollama.CONF_THINK: True,
}

View File

@ -15,7 +15,12 @@ from homeassistant.components.conversation import trace
from homeassistant.const import ATTR_SUPPORTED_FEATURES, CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm
from homeassistant.helpers import (
device_registry as dr,
entity_registry as er,
intent,
llm,
)
from tests.common import MockConfigEntry
@ -68,7 +73,7 @@ async def test_chat(
args = mock_chat.call_args.kwargs
prompt = args["messages"][0]["content"]
assert args["model"] == "test model"
assert args["model"] == "test_model:latest"
assert args["messages"] == [
Message(role="system", content=prompt),
Message(role="user", content="test message"),
@ -128,7 +133,7 @@ async def test_chat_stream(
args = mock_chat.call_args.kwargs
prompt = args["messages"][0]["content"]
assert args["model"] == "test model"
assert args["model"] == "test_model:latest"
assert args["messages"] == [
Message(role="system", content=prompt),
Message(role="user", content="test message"),
@ -158,6 +163,7 @@ async def test_template_variables(
"The user name is {{ user_name }}. "
"The user id is {{ llm_context.context.user_id }}."
),
ollama.CONF_MODEL: "test_model:latest",
},
)
with (
@ -524,7 +530,9 @@ async def test_message_history_unlimited(
):
subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry(
mock_config_entry, subentry, data={ollama.CONF_MAX_HISTORY: 0}
mock_config_entry,
subentry,
data={**subentry.data, ollama.CONF_MAX_HISTORY: 0},
)
for i in range(100):
result = await conversation.async_converse(
@ -573,6 +581,7 @@ async def test_template_error(
mock_config_entry,
subentry,
data={
**subentry.data,
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
},
)
@ -593,6 +602,8 @@ async def test_conversation_agent(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test OllamaConversationEntity."""
agent = conversation.get_agent_manager(hass).async_get_agent(
@ -604,6 +615,24 @@ async def test_conversation_agent(
assert state
assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0
entity_entry = entity_registry.async_get("conversation.ollama_conversation")
assert entity_entry
subentry = mock_config_entry.subentries.get(entity_entry.unique_id)
assert subentry
assert entity_entry.original_name == subentry.title
device_entry = device_registry.async_get(entity_entry.device_id)
assert device_entry
assert device_entry.identifiers == {(ollama.DOMAIN, subentry.subentry_id)}
assert device_entry.name == subentry.title
assert device_entry.manufacturer == "Ollama"
assert device_entry.entry_type == dr.DeviceEntryType.SERVICE
model, _, version = subentry.data[ollama.CONF_MODEL].partition(":")
assert device_entry.model == model
assert device_entry.sw_version == version
async def test_conversation_agent_with_assist(
hass: HomeAssistant,
@ -679,6 +708,7 @@ async def test_reasoning_filter(
mock_config_entry,
subentry,
data={
**subentry.data,
ollama.CONF_THINK: think,
},
)

View File

@ -9,13 +9,26 @@ from homeassistant.components import ollama
from homeassistant.components.ollama.const import DOMAIN
from homeassistant.config_entries import ConfigSubentryData
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers import device_registry as dr, entity_registry as er, llm
from homeassistant.setup import async_setup_component
from . import TEST_OPTIONS, TEST_USER_DATA
from . import TEST_OPTIONS
from tests.common import MockConfigEntry
V1_TEST_USER_DATA = {
ollama.CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: "test_model:latest",
}
V1_TEST_OPTIONS = {
ollama.CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
ollama.CONF_MAX_HISTORY: 2,
}
V21_TEST_USER_DATA = V1_TEST_USER_DATA
V21_TEST_OPTIONS = V1_TEST_OPTIONS
@pytest.mark.parametrize(
("side_effect", "error"),
@ -41,17 +54,17 @@ async def test_init_error(
assert error in caplog.text
async def test_migration_from_v1_to_v2(
async def test_migration_from_v1(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 1 to version 2."""
"""Test migration from version 1."""
# Create a v1 config entry with conversation options and an entity
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data=TEST_USER_DATA,
options=TEST_OPTIONS,
data=V1_TEST_USER_DATA,
options=V1_TEST_OPTIONS,
version=1,
title="llama-3.2-8b",
)
@ -81,9 +94,10 @@ async def test_migration_from_v1_to_v2(
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
assert mock_config_entry.version == 2
assert mock_config_entry.minor_version == 2
assert mock_config_entry.data == TEST_USER_DATA
assert mock_config_entry.version == 3
assert mock_config_entry.minor_version == 1
# After migration, parent entry should only have URL
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
assert mock_config_entry.options == {}
assert len(mock_config_entry.subentries) == 1
@ -92,7 +106,9 @@ async def test_migration_from_v1_to_v2(
assert subentry.unique_id is None
assert subentry.title == "llama-3.2-8b"
assert subentry.subentry_type == "conversation"
assert subentry.data == TEST_OPTIONS
# Subentry should now include the model from the original options
expected_subentry_data = TEST_OPTIONS.copy()
assert subentry.data == expected_subentry_data
migrated_entity = entity_registry.async_get(entity.entity_id)
assert migrated_entity is not None
@ -117,17 +133,17 @@ async def test_migration_from_v1_to_v2(
}
async def test_migration_from_v1_to_v2_with_multiple_urls(
async def test_migration_from_v1_with_multiple_urls(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 1 to version 2 with different URLs."""
"""Test migration from version 1 with different URLs."""
# Create two v1 config entries with different URLs
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
options=TEST_OPTIONS,
options=V1_TEST_OPTIONS,
version=1,
title="Ollama 1",
)
@ -135,7 +151,7 @@ async def test_migration_from_v1_to_v2_with_multiple_urls(
mock_config_entry_2 = MockConfigEntry(
domain=DOMAIN,
data={"url": "http://localhost:11435", "model": "llama3.2:latest"},
options=TEST_OPTIONS,
options=V1_TEST_OPTIONS,
version=1,
title="Ollama 2",
)
@ -187,13 +203,16 @@ async def test_migration_from_v1_to_v2_with_multiple_urls(
assert len(entries) == 2
for idx, entry in enumerate(entries):
assert entry.version == 2
assert entry.minor_version == 2
assert entry.version == 3
assert entry.minor_version == 1
assert not entry.options
assert len(entry.subentries) == 1
subentry = list(entry.subentries.values())[0]
assert subentry.subentry_type == "conversation"
assert subentry.data == TEST_OPTIONS
# Subentry should include the model along with the original options
expected_subentry_data = TEST_OPTIONS.copy()
expected_subentry_data["model"] = "llama3.2:latest"
assert subentry.data == expected_subentry_data
assert subentry.title == f"Ollama {idx + 1}"
dev = device_registry.async_get_device(
@ -204,17 +223,17 @@ async def test_migration_from_v1_to_v2_with_multiple_urls(
assert dev.config_entries_subentries == {entry.entry_id: {subentry.subentry_id}}
async def test_migration_from_v1_to_v2_with_same_urls(
async def test_migration_from_v1_with_same_urls(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 1 to version 2 with same URLs consolidates entries."""
"""Test migration from version 1 with same URLs consolidates entries."""
# Create two v1 config entries with the same URL
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
options=TEST_OPTIONS,
options=V1_TEST_OPTIONS,
version=1,
title="Ollama",
)
@ -222,7 +241,7 @@ async def test_migration_from_v1_to_v2_with_same_urls(
mock_config_entry_2 = MockConfigEntry(
domain=DOMAIN,
data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, # Same URL
options=TEST_OPTIONS,
options=V1_TEST_OPTIONS,
version=1,
title="Ollama 2",
)
@ -275,8 +294,8 @@ async def test_migration_from_v1_to_v2_with_same_urls(
assert len(entries) == 1
entry = entries[0]
assert entry.version == 2
assert entry.minor_version == 2
assert entry.version == 3
assert entry.minor_version == 1
assert not entry.options
assert len(entry.subentries) == 2 # Two subentries from the two original entries
@ -288,7 +307,10 @@ async def test_migration_from_v1_to_v2_with_same_urls(
for subentry in subentries:
assert subentry.subentry_type == "conversation"
assert subentry.data == TEST_OPTIONS
# Subentry should include the model along with the original options
expected_subentry_data = TEST_OPTIONS.copy()
expected_subentry_data["model"] = "llama3.2:latest"
assert subentry.data == expected_subentry_data
# Check devices were migrated correctly
dev = device_registry.async_get_device(
@ -301,12 +323,12 @@ async def test_migration_from_v1_to_v2_with_same_urls(
}
async def test_migration_from_v2_1_to_v2_2(
async def test_migration_from_v2_1(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 2.1 to version 2.2.
"""Test migration from version 2.1.
This tests we clean up the broken migration in Home Assistant Core
2025.7.0b0-2025.7.0b1:
@ -315,20 +337,20 @@ async def test_migration_from_v2_1_to_v2_2(
# Create a v2.1 config entry with 2 subentries, devices and entities
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data=TEST_USER_DATA,
data=V21_TEST_USER_DATA,
entry_id="mock_entry_id",
version=2,
minor_version=1,
subentries_data=[
ConfigSubentryData(
data=TEST_OPTIONS,
data=V21_TEST_OPTIONS,
subentry_id="mock_id_1",
subentry_type="conversation",
title="Ollama",
unique_id=None,
),
ConfigSubentryData(
data=TEST_OPTIONS,
data=V21_TEST_OPTIONS,
subentry_id="mock_id_2",
subentry_type="conversation",
title="Ollama 2",
@ -392,8 +414,8 @@ async def test_migration_from_v2_1_to_v2_2(
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.version == 2
assert entry.minor_version == 2
assert entry.version == 3
assert entry.minor_version == 1
assert not entry.options
assert entry.title == "Ollama"
assert len(entry.subentries) == 2
@ -405,6 +427,7 @@ async def test_migration_from_v2_1_to_v2_2(
assert len(conversation_subentries) == 2
for subentry in conversation_subentries:
assert subentry.subentry_type == "conversation"
# Since TEST_USER_DATA no longer has a model, subentry data should be TEST_OPTIONS
assert subentry.data == TEST_OPTIONS
assert "Ollama" in subentry.title
@ -450,3 +473,45 @@ async def test_migration_from_v2_1_to_v2_2(
assert device.config_entries_subentries == {
mock_config_entry.entry_id: {subentry.subentry_id}
}
async def test_migration_from_v2_2(hass: HomeAssistant) -> None:
"""Test migration from version 2.2."""
subentry_data = ConfigSubentryData(
data=V21_TEST_USER_DATA,
subentry_type="conversation",
title="Test Conversation",
unique_id=None,
)
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={
ollama.CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: "test_model:latest", # Model still in main data
},
version=2,
minor_version=2,
subentries_data=[subentry_data],
)
mock_config_entry.add_to_hass(hass)
with patch(
"homeassistant.components.ollama.async_setup_entry",
return_value=True,
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
# Check migration to v3.1
assert mock_config_entry.version == 3
assert mock_config_entry.minor_version == 1
# Check that model was moved from main data to subentry
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
assert len(mock_config_entry.subentries) == 1
subentry = next(iter(mock_config_entry.subentries.values()))
assert subentry.data == {
**V21_TEST_USER_DATA,
ollama.CONF_MODEL: "test_model:latest",
}