mirror of
https://github.com/home-assistant/core.git
synced 2025-07-09 06:17:07 +00:00
Ollama: Migrate pick model to subentry (#147944)
This commit is contained in:
parent
943fb9948b
commit
f50ef79c72
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
|
||||
import httpx
|
||||
import ollama
|
||||
@ -100,8 +101,12 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
|
||||
|
||||
for entry in entries:
|
||||
use_existing = False
|
||||
# Create subentry with model from entry.data and options from entry.options
|
||||
subentry_data = entry.options.copy()
|
||||
subentry_data[CONF_MODEL] = entry.data[CONF_MODEL]
|
||||
|
||||
subentry = ConfigSubentry(
|
||||
data=entry.options,
|
||||
data=MappingProxyType(subentry_data),
|
||||
subentry_type="conversation",
|
||||
title=entry.title,
|
||||
unique_id=None,
|
||||
@ -154,9 +159,11 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
|
||||
hass.config_entries.async_update_entry(
|
||||
entry,
|
||||
title=DEFAULT_NAME,
|
||||
# Update parent entry to only keep URL, remove model
|
||||
data={CONF_URL: entry.data[CONF_URL]},
|
||||
options={},
|
||||
version=2,
|
||||
minor_version=2,
|
||||
version=3,
|
||||
minor_version=1,
|
||||
)
|
||||
|
||||
|
||||
@ -164,7 +171,7 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OllamaConfigEntry) ->
|
||||
"""Migrate entry."""
|
||||
_LOGGER.debug("Migrating from version %s:%s", entry.version, entry.minor_version)
|
||||
|
||||
if entry.version > 2:
|
||||
if entry.version > 3:
|
||||
# This means the user has downgraded from a future version
|
||||
return False
|
||||
|
||||
@ -182,6 +189,25 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OllamaConfigEntry) ->
|
||||
|
||||
hass.config_entries.async_update_entry(entry, minor_version=2)
|
||||
|
||||
if entry.version == 2 and entry.minor_version == 2:
|
||||
# Update subentries to include the model
|
||||
for subentry in entry.subentries.values():
|
||||
if subentry.subentry_type == "conversation":
|
||||
updated_data = dict(subentry.data)
|
||||
updated_data[CONF_MODEL] = entry.data[CONF_MODEL]
|
||||
|
||||
hass.config_entries.async_update_subentry(
|
||||
entry, subentry, data=MappingProxyType(updated_data)
|
||||
)
|
||||
|
||||
# Update main entry to remove model and bump version
|
||||
hass.config_entries.async_update_entry(
|
||||
entry,
|
||||
data={CONF_URL: entry.data[CONF_URL]},
|
||||
version=3,
|
||||
minor_version=1,
|
||||
)
|
||||
|
||||
_LOGGER.debug(
|
||||
"Migration to version %s:%s successful", entry.version, entry.minor_version
|
||||
)
|
||||
|
@ -22,7 +22,7 @@ from homeassistant.config_entries import (
|
||||
)
|
||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers import config_validation as cv, llm
|
||||
from homeassistant.helpers.selector import (
|
||||
BooleanSelector,
|
||||
NumberSelector,
|
||||
@ -38,6 +38,7 @@ from homeassistant.helpers.selector import (
|
||||
)
|
||||
from homeassistant.util.ssl import get_default_context
|
||||
|
||||
from . import OllamaConfigEntry
|
||||
from .const import (
|
||||
CONF_KEEP_ALIVE,
|
||||
CONF_MAX_HISTORY,
|
||||
@ -72,43 +73,43 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||
class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for Ollama."""
|
||||
|
||||
VERSION = 2
|
||||
MINOR_VERSION = 2
|
||||
VERSION = 3
|
||||
MINOR_VERSION = 1
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize config flow."""
|
||||
self.url: str | None = None
|
||||
self.model: str | None = None
|
||||
self.client: ollama.AsyncClient | None = None
|
||||
self.download_task: asyncio.Task | None = None
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle the initial step."""
|
||||
user_input = user_input or {}
|
||||
self.url = user_input.get(CONF_URL, self.url)
|
||||
self.model = user_input.get(CONF_MODEL, self.model)
|
||||
|
||||
if self.url is None:
|
||||
if user_input is None:
|
||||
return self.async_show_form(
|
||||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, last_step=False
|
||||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA
|
||||
)
|
||||
|
||||
errors = {}
|
||||
url = user_input[CONF_URL]
|
||||
|
||||
self._async_abort_entries_match({CONF_URL: self.url})
|
||||
self._async_abort_entries_match({CONF_URL: url})
|
||||
|
||||
try:
|
||||
self.client = ollama.AsyncClient(
|
||||
host=self.url, verify=get_default_context()
|
||||
url = cv.url(url)
|
||||
except vol.Invalid:
|
||||
errors["base"] = "invalid_url"
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=self.add_suggested_values_to_schema(
|
||||
STEP_USER_DATA_SCHEMA, user_input
|
||||
),
|
||||
errors=errors,
|
||||
)
|
||||
async with asyncio.timeout(DEFAULT_TIMEOUT):
|
||||
response = await self.client.list()
|
||||
|
||||
downloaded_models: set[str] = {
|
||||
model_info["model"] for model_info in response.get("models", [])
|
||||
}
|
||||
try:
|
||||
client = ollama.AsyncClient(host=url, verify=get_default_context())
|
||||
async with asyncio.timeout(DEFAULT_TIMEOUT):
|
||||
await client.list()
|
||||
except (TimeoutError, httpx.ConnectError):
|
||||
errors["base"] = "cannot_connect"
|
||||
except Exception:
|
||||
@ -117,10 +118,69 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
|
||||
if errors:
|
||||
return self.async_show_form(
|
||||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
|
||||
step_id="user",
|
||||
data_schema=self.add_suggested_values_to_schema(
|
||||
STEP_USER_DATA_SCHEMA, user_input
|
||||
),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
if self.model is None:
|
||||
return self.async_create_entry(
|
||||
title=url,
|
||||
data={CONF_URL: url},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@callback
|
||||
def async_get_supported_subentry_types(
|
||||
cls, config_entry: ConfigEntry
|
||||
) -> dict[str, type[ConfigSubentryFlow]]:
|
||||
"""Return subentries supported by this integration."""
|
||||
return {"conversation": ConversationSubentryFlowHandler}
|
||||
|
||||
|
||||
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
"""Flow for managing conversation subentries."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the subentry flow."""
|
||||
super().__init__()
|
||||
self._name: str | None = None
|
||||
self._model: str | None = None
|
||||
self.download_task: asyncio.Task | None = None
|
||||
self._config_data: dict[str, Any] | None = None
|
||||
|
||||
@property
|
||||
def _is_new(self) -> bool:
|
||||
"""Return if this is a new subentry."""
|
||||
return self.source == "user"
|
||||
|
||||
@property
|
||||
def _client(self) -> ollama.AsyncClient:
|
||||
"""Return the Ollama client."""
|
||||
entry: OllamaConfigEntry = self._get_entry()
|
||||
return entry.runtime_data
|
||||
|
||||
async def async_step_set_options(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> SubentryFlowResult:
|
||||
"""Handle model selection and configuration step."""
|
||||
if self._get_entry().state != ConfigEntryState.LOADED:
|
||||
return self.async_abort(reason="entry_not_loaded")
|
||||
|
||||
if user_input is None:
|
||||
# Get available models from Ollama server
|
||||
try:
|
||||
async with asyncio.timeout(DEFAULT_TIMEOUT):
|
||||
response = await self._client.list()
|
||||
|
||||
downloaded_models: set[str] = {
|
||||
model_info["model"] for model_info in response.get("models", [])
|
||||
}
|
||||
except (TimeoutError, httpx.ConnectError, httpx.HTTPError):
|
||||
_LOGGER.exception("Failed to get models from Ollama server")
|
||||
return self.async_abort(reason="cannot_connect")
|
||||
|
||||
# Show models that have been downloaded first, followed by all known
|
||||
# models (only latest tags).
|
||||
models_to_list = [
|
||||
@ -131,52 +191,69 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
for m in sorted(MODEL_NAMES)
|
||||
if m not in downloaded_models
|
||||
]
|
||||
model_step_schema = vol.Schema(
|
||||
{
|
||||
vol.Required(
|
||||
CONF_MODEL, description={"suggested_value": DEFAULT_MODEL}
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(options=models_to_list, custom_value=True)
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
if self._is_new:
|
||||
options = {}
|
||||
else:
|
||||
options = self._get_reconfigure_subentry().data.copy()
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=model_step_schema,
|
||||
step_id="set_options",
|
||||
data_schema=vol.Schema(
|
||||
ollama_config_option_schema(
|
||||
self.hass, self._is_new, options, models_to_list
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
if self.model not in downloaded_models:
|
||||
self._model = user_input[CONF_MODEL]
|
||||
if self._is_new:
|
||||
self._name = user_input.pop(CONF_NAME)
|
||||
|
||||
# Check if model needs to be downloaded
|
||||
try:
|
||||
async with asyncio.timeout(DEFAULT_TIMEOUT):
|
||||
response = await self._client.list()
|
||||
|
||||
currently_downloaded_models: set[str] = {
|
||||
model_info["model"] for model_info in response.get("models", [])
|
||||
}
|
||||
|
||||
if self._model not in currently_downloaded_models:
|
||||
# Store the user input to use after download
|
||||
self._config_data = user_input
|
||||
# Ollama server needs to download model first
|
||||
return await self.async_step_download()
|
||||
except Exception:
|
||||
_LOGGER.exception("Failed to check model availability")
|
||||
return self.async_abort(reason="cannot_connect")
|
||||
|
||||
# Model is already downloaded, create/update the entry
|
||||
if self._is_new:
|
||||
return self.async_create_entry(
|
||||
title=self.url,
|
||||
data={CONF_URL: self.url, CONF_MODEL: self.model},
|
||||
subentries=[
|
||||
{
|
||||
"subentry_type": "conversation",
|
||||
"data": {},
|
||||
"title": _get_title(self.model),
|
||||
"unique_id": None,
|
||||
}
|
||||
],
|
||||
title=self._name,
|
||||
data=user_input,
|
||||
)
|
||||
|
||||
return self.async_update_and_abort(
|
||||
self._get_entry(),
|
||||
self._get_reconfigure_subentry(),
|
||||
data=user_input,
|
||||
)
|
||||
|
||||
async def async_step_download(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
) -> SubentryFlowResult:
|
||||
"""Step to wait for Ollama server to download a model."""
|
||||
assert self.model is not None
|
||||
assert self.client is not None
|
||||
assert self._model is not None
|
||||
|
||||
if self.download_task is None:
|
||||
# Tell Ollama server to pull the model.
|
||||
# The task will block until the model and metadata are fully
|
||||
# downloaded.
|
||||
self.download_task = self.hass.async_create_background_task(
|
||||
self.client.pull(self.model),
|
||||
f"Downloading {self.model}",
|
||||
self._client.pull(self._model),
|
||||
f"Downloading {self._model}",
|
||||
)
|
||||
|
||||
if self.download_task.done():
|
||||
@ -192,80 +269,28 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
progress_task=self.download_task,
|
||||
)
|
||||
|
||||
async def async_step_finish(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Step after model downloading has succeeded."""
|
||||
assert self.url is not None
|
||||
assert self.model is not None
|
||||
|
||||
return self.async_create_entry(
|
||||
title=_get_title(self.model),
|
||||
data={CONF_URL: self.url, CONF_MODEL: self.model},
|
||||
subentries=[
|
||||
{
|
||||
"subentry_type": "conversation",
|
||||
"data": {},
|
||||
"title": _get_title(self.model),
|
||||
"unique_id": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
async def async_step_failed(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
) -> SubentryFlowResult:
|
||||
"""Step after model downloading has failed."""
|
||||
return self.async_abort(reason="download_failed")
|
||||
|
||||
@classmethod
|
||||
@callback
|
||||
def async_get_supported_subentry_types(
|
||||
cls, config_entry: ConfigEntry
|
||||
) -> dict[str, type[ConfigSubentryFlow]]:
|
||||
"""Return subentries supported by this integration."""
|
||||
return {"conversation": ConversationSubentryFlowHandler}
|
||||
|
||||
|
||||
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
"""Flow for managing conversation subentries."""
|
||||
|
||||
@property
|
||||
def _is_new(self) -> bool:
|
||||
"""Return if this is a new subentry."""
|
||||
return self.source == "user"
|
||||
|
||||
async def async_step_set_options(
|
||||
async def async_step_finish(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> SubentryFlowResult:
|
||||
"""Set conversation options."""
|
||||
# abort if entry is not loaded
|
||||
if self._get_entry().state != ConfigEntryState.LOADED:
|
||||
return self.async_abort(reason="entry_not_loaded")
|
||||
"""Step after model downloading has succeeded."""
|
||||
assert self._config_data is not None
|
||||
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
if user_input is None:
|
||||
# Model download completed, create/update the entry with stored config
|
||||
if self._is_new:
|
||||
options = {}
|
||||
else:
|
||||
options = self._get_reconfigure_subentry().data.copy()
|
||||
|
||||
elif self._is_new:
|
||||
return self.async_create_entry(
|
||||
title=user_input.pop(CONF_NAME),
|
||||
data=user_input,
|
||||
title=self._name,
|
||||
data=self._config_data,
|
||||
)
|
||||
else:
|
||||
return self.async_update_and_abort(
|
||||
self._get_entry(),
|
||||
self._get_reconfigure_subentry(),
|
||||
data=user_input,
|
||||
)
|
||||
|
||||
schema = ollama_config_option_schema(self.hass, self._is_new, options)
|
||||
return self.async_show_form(
|
||||
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
|
||||
data=self._config_data,
|
||||
)
|
||||
|
||||
async_step_user = async_step_set_options
|
||||
@ -273,19 +298,14 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
|
||||
|
||||
def ollama_config_option_schema(
|
||||
hass: HomeAssistant, is_new: bool, options: Mapping[str, Any]
|
||||
hass: HomeAssistant,
|
||||
is_new: bool,
|
||||
options: Mapping[str, Any],
|
||||
models_to_list: list[SelectOptionDict],
|
||||
) -> dict:
|
||||
"""Ollama options schema."""
|
||||
hass_apis: list[SelectOptionDict] = [
|
||||
SelectOptionDict(
|
||||
label=api.name,
|
||||
value=api.id,
|
||||
)
|
||||
for api in llm.async_get_apis(hass)
|
||||
]
|
||||
|
||||
if is_new:
|
||||
schema: dict[vol.Required | vol.Optional, Any] = {
|
||||
schema: dict = {
|
||||
vol.Required(CONF_NAME, default="Ollama Conversation"): str,
|
||||
}
|
||||
else:
|
||||
@ -293,6 +313,12 @@ def ollama_config_option_schema(
|
||||
|
||||
schema.update(
|
||||
{
|
||||
vol.Required(
|
||||
CONF_MODEL,
|
||||
description={"suggested_value": options.get(CONF_MODEL, DEFAULT_MODEL)},
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(options=models_to_list, custom_value=True)
|
||||
),
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={
|
||||
@ -304,7 +330,18 @@ def ollama_config_option_schema(
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(
|
||||
options=[
|
||||
SelectOptionDict(
|
||||
label=api.name,
|
||||
value=api.id,
|
||||
)
|
||||
for api in llm.async_get_apis(hass)
|
||||
],
|
||||
multiple=True,
|
||||
)
|
||||
),
|
||||
vol.Optional(
|
||||
CONF_NUM_CTX,
|
||||
description={
|
||||
@ -350,11 +387,3 @@ def ollama_config_option_schema(
|
||||
)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def _get_title(model: str) -> str:
|
||||
"""Get title for config entry."""
|
||||
if model.endswith(":latest"):
|
||||
model = model.split(":", maxsplit=1)[0]
|
||||
|
||||
return model
|
||||
|
@ -166,11 +166,14 @@ class OllamaBaseLLMEntity(Entity):
|
||||
self.subentry = subentry
|
||||
self._attr_name = subentry.title
|
||||
self._attr_unique_id = subentry.subentry_id
|
||||
|
||||
model, _, version = subentry.data[CONF_MODEL].partition(":")
|
||||
self._attr_device_info = dr.DeviceInfo(
|
||||
identifiers={(DOMAIN, subentry.subentry_id)},
|
||||
name=subentry.title,
|
||||
manufacturer="Ollama",
|
||||
model=entry.data[CONF_MODEL],
|
||||
model=model,
|
||||
sw_version=version or "latest",
|
||||
entry_type=dr.DeviceEntryType.SERVICE,
|
||||
)
|
||||
|
||||
|
@ -3,24 +3,17 @@
|
||||
"step": {
|
||||
"user": {
|
||||
"data": {
|
||||
"url": "[%key:common::config_flow::data::url%]",
|
||||
"model": "Model"
|
||||
"url": "[%key:common::config_flow::data::url%]"
|
||||
}
|
||||
},
|
||||
"download": {
|
||||
"title": "Downloading model"
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"download_failed": "Model downloading failed",
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]"
|
||||
},
|
||||
"error": {
|
||||
"invalid_url": "[%key:common::config_flow::error::invalid_host%]",
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||
"unknown": "[%key:common::config_flow::error::unknown%]"
|
||||
},
|
||||
"progress": {
|
||||
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
|
||||
}
|
||||
},
|
||||
"config_subentries": {
|
||||
@ -33,6 +26,7 @@
|
||||
"step": {
|
||||
"set_options": {
|
||||
"data": {
|
||||
"model": "Model",
|
||||
"name": "[%key:common::config_flow::data::name%]",
|
||||
"prompt": "Instructions",
|
||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
|
||||
@ -47,11 +41,19 @@
|
||||
"num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.",
|
||||
"think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency."
|
||||
}
|
||||
},
|
||||
"download": {
|
||||
"title": "Downloading model"
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
|
||||
"entry_not_loaded": "Cannot add things while the configuration is disabled."
|
||||
"entry_not_loaded": "Failed to add agent. The configuration is disabled.",
|
||||
"download_failed": "Model downloading failed",
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
|
||||
},
|
||||
"progress": {
|
||||
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,10 +5,10 @@ from homeassistant.helpers import llm
|
||||
|
||||
TEST_USER_DATA = {
|
||||
ollama.CONF_URL: "http://localhost:11434",
|
||||
ollama.CONF_MODEL: "test model",
|
||||
}
|
||||
|
||||
TEST_OPTIONS = {
|
||||
ollama.CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||
ollama.CONF_MAX_HISTORY: 2,
|
||||
ollama.CONF_MODEL: "test_model:latest",
|
||||
}
|
||||
|
@ -30,10 +30,11 @@ def mock_config_entry(
|
||||
entry = MockConfigEntry(
|
||||
domain=ollama.DOMAIN,
|
||||
data=TEST_USER_DATA,
|
||||
version=2,
|
||||
version=3,
|
||||
minor_version=1,
|
||||
subentries_data=[
|
||||
{
|
||||
"data": mock_config_entry_options,
|
||||
"data": {**TEST_OPTIONS, **mock_config_entry_options},
|
||||
"subentry_type": "conversation",
|
||||
"title": "Ollama Conversation",
|
||||
"unique_id": None,
|
||||
@ -49,10 +50,14 @@ def mock_config_entry_with_assist(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> MockConfigEntry:
|
||||
"""Mock a config entry with assist."""
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
hass.config_entries.async_update_subentry(
|
||||
mock_config_entry,
|
||||
next(iter(mock_config_entry.subentries.values())),
|
||||
data={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
|
||||
subentry,
|
||||
data={
|
||||
**subentry.data,
|
||||
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||
},
|
||||
)
|
||||
return mock_config_entry
|
||||
|
||||
|
@ -8,6 +8,7 @@ import pytest
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components import ollama
|
||||
from homeassistant.const import CONF_NAME
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
@ -17,7 +18,7 @@ TEST_MODEL = "test_model:latest"
|
||||
|
||||
|
||||
async def test_form(hass: HomeAssistant) -> None:
|
||||
"""Test flow when the model is already downloaded."""
|
||||
"""Test flow when configuring URL only."""
|
||||
# Pretend we already set up a config entry.
|
||||
hass.config.components.add(ollama.DOMAIN)
|
||||
MockConfigEntry(
|
||||
@ -34,7 +35,6 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
|
||||
# test model is already "downloaded"
|
||||
return_value={"models": [{"model": TEST_MODEL}]},
|
||||
),
|
||||
patch(
|
||||
@ -42,24 +42,17 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
return_value=True,
|
||||
) as mock_setup_entry,
|
||||
):
|
||||
# Step 1: URL
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Step 2: model
|
||||
assert result2["type"] is FlowResultType.FORM
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result3["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result3["data"] == {
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result2["data"] == {
|
||||
ollama.CONF_URL: "http://localhost:11434",
|
||||
ollama.CONF_MODEL: TEST_MODEL,
|
||||
}
|
||||
# No subentries created by default
|
||||
assert len(result2.get("subentries", [])) == 0
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
@ -94,98 +87,6 @@ async def test_duplicate_entry(hass: HomeAssistant) -> None:
|
||||
assert result["reason"] == "already_configured"
|
||||
|
||||
|
||||
async def test_form_need_download(hass: HomeAssistant) -> None:
|
||||
"""Test flow when a model needs to be downloaded."""
|
||||
# Pretend we already set up a config entry.
|
||||
hass.config.components.add(ollama.DOMAIN)
|
||||
MockConfigEntry(
|
||||
domain=ollama.DOMAIN,
|
||||
state=config_entries.ConfigEntryState.LOADED,
|
||||
).add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["errors"] is None
|
||||
|
||||
pull_ready = asyncio.Event()
|
||||
pull_called = asyncio.Event()
|
||||
pull_model: str | None = None
|
||||
|
||||
async def pull(self, model: str, *args, **kwargs) -> None:
|
||||
nonlocal pull_model
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await pull_ready.wait()
|
||||
|
||||
pull_model = model
|
||||
pull_called.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
|
||||
# No models are downloaded
|
||||
return_value={},
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull",
|
||||
pull,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.ollama.async_setup_entry",
|
||||
return_value=True,
|
||||
) as mock_setup_entry,
|
||||
):
|
||||
# Step 1: URL
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Step 2: model
|
||||
assert result2["type"] is FlowResultType.FORM
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Step 3: download
|
||||
assert result3["type"] is FlowResultType.SHOW_PROGRESS
|
||||
result4 = await hass.config_entries.flow.async_configure(
|
||||
result3["flow_id"],
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Run again without the task finishing.
|
||||
# We should still be downloading.
|
||||
assert result4["type"] is FlowResultType.SHOW_PROGRESS
|
||||
result4 = await hass.config_entries.flow.async_configure(
|
||||
result4["flow_id"],
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert result4["type"] is FlowResultType.SHOW_PROGRESS
|
||||
|
||||
# Signal fake pull method to complete
|
||||
pull_ready.set()
|
||||
async with asyncio.timeout(1):
|
||||
await pull_called.wait()
|
||||
|
||||
assert pull_model == TEST_MODEL
|
||||
|
||||
# Step 4: finish
|
||||
result5 = await hass.config_entries.flow.async_configure(
|
||||
result4["flow_id"],
|
||||
)
|
||||
|
||||
assert result5["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result5["data"] == {
|
||||
ollama.CONF_URL: "http://localhost:11434",
|
||||
ollama.CONF_MODEL: TEST_MODEL,
|
||||
}
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_subentry_options(
|
||||
hass: HomeAssistant, mock_config_entry, mock_init_component
|
||||
) -> None:
|
||||
@ -193,6 +94,10 @@ async def test_subentry_options(
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
|
||||
# Test reconfiguration
|
||||
with patch(
|
||||
"ollama.AsyncClient.list",
|
||||
return_value={"models": [{"model": TEST_MODEL}]},
|
||||
):
|
||||
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
|
||||
hass, subentry.subentry_id
|
||||
)
|
||||
@ -203,6 +108,7 @@ async def test_subentry_options(
|
||||
options = await hass.config_entries.subentries.async_configure(
|
||||
options_flow["flow_id"],
|
||||
{
|
||||
ollama.CONF_MODEL: TEST_MODEL,
|
||||
ollama.CONF_PROMPT: "test prompt",
|
||||
ollama.CONF_MAX_HISTORY: 100,
|
||||
ollama.CONF_NUM_CTX: 32768,
|
||||
@ -214,13 +120,58 @@ async def test_subentry_options(
|
||||
assert options["type"] is FlowResultType.ABORT
|
||||
assert options["reason"] == "reconfigure_successful"
|
||||
assert subentry.data == {
|
||||
ollama.CONF_MODEL: TEST_MODEL,
|
||||
ollama.CONF_PROMPT: "test prompt",
|
||||
ollama.CONF_MAX_HISTORY: 100,
|
||||
ollama.CONF_NUM_CTX: 32768,
|
||||
ollama.CONF_MAX_HISTORY: 100.0,
|
||||
ollama.CONF_NUM_CTX: 32768.0,
|
||||
ollama.CONF_THINK: True,
|
||||
}
|
||||
|
||||
|
||||
async def test_creating_new_conversation_subentry(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test creating a new conversation subentry includes name field."""
|
||||
# Start a new subentry flow
|
||||
with patch(
|
||||
"ollama.AsyncClient.list",
|
||||
return_value={"models": [{"model": TEST_MODEL}]},
|
||||
):
|
||||
new_flow = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert new_flow["type"] is FlowResultType.FORM
|
||||
assert new_flow["step_id"] == "set_options"
|
||||
|
||||
# Configure the new subentry with name field
|
||||
result = await hass.config_entries.subentries.async_configure(
|
||||
new_flow["flow_id"],
|
||||
{
|
||||
ollama.CONF_MODEL: TEST_MODEL,
|
||||
CONF_NAME: "New Test Conversation",
|
||||
ollama.CONF_PROMPT: "new test prompt",
|
||||
ollama.CONF_MAX_HISTORY: 50,
|
||||
ollama.CONF_NUM_CTX: 16384,
|
||||
ollama.CONF_THINK: False,
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == "New Test Conversation"
|
||||
assert result["data"] == {
|
||||
ollama.CONF_MODEL: TEST_MODEL,
|
||||
ollama.CONF_PROMPT: "new test prompt",
|
||||
ollama.CONF_MAX_HISTORY: 50.0,
|
||||
ollama.CONF_NUM_CTX: 16384.0,
|
||||
ollama.CONF_THINK: False,
|
||||
}
|
||||
|
||||
|
||||
async def test_creating_conversation_subentry_not_loaded(
|
||||
hass: HomeAssistant,
|
||||
mock_init_component,
|
||||
@ -237,6 +188,125 @@ async def test_creating_conversation_subentry_not_loaded(
|
||||
assert result["reason"] == "entry_not_loaded"
|
||||
|
||||
|
||||
async def test_subentry_need_download(
|
||||
hass: HomeAssistant,
|
||||
mock_init_component,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test subentry creation when model needs to be downloaded."""
|
||||
|
||||
async def delayed_pull(self, model: str) -> None:
|
||||
"""Simulate a delayed model download."""
|
||||
assert model == "llama3.2:latest"
|
||||
await asyncio.sleep(0) # yield the event loop 1 iteration
|
||||
|
||||
with (
|
||||
patch(
|
||||
"ollama.AsyncClient.list",
|
||||
return_value={"models": [{"model": TEST_MODEL}]},
|
||||
),
|
||||
patch("ollama.AsyncClient.pull", delayed_pull),
|
||||
):
|
||||
new_flow = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert new_flow["type"] is FlowResultType.FORM, new_flow
|
||||
assert new_flow["step_id"] == "set_options"
|
||||
|
||||
# Configure the new subentry with a model that needs downloading
|
||||
result = await hass.config_entries.subentries.async_configure(
|
||||
new_flow["flow_id"],
|
||||
{
|
||||
ollama.CONF_MODEL: "llama3.2:latest", # not cached
|
||||
CONF_NAME: "New Test Conversation",
|
||||
ollama.CONF_PROMPT: "new test prompt",
|
||||
ollama.CONF_MAX_HISTORY: 50,
|
||||
ollama.CONF_NUM_CTX: 16384,
|
||||
ollama.CONF_THINK: False,
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.SHOW_PROGRESS
|
||||
assert result["step_id"] == "download"
|
||||
assert result["progress_action"] == "download"
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
result = await hass.config_entries.subentries.async_configure(
|
||||
new_flow["flow_id"], {}
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == "New Test Conversation"
|
||||
assert result["data"] == {
|
||||
ollama.CONF_MODEL: "llama3.2:latest",
|
||||
ollama.CONF_PROMPT: "new test prompt",
|
||||
ollama.CONF_MAX_HISTORY: 50.0,
|
||||
ollama.CONF_NUM_CTX: 16384.0,
|
||||
ollama.CONF_THINK: False,
|
||||
}
|
||||
|
||||
|
||||
async def test_subentry_download_error(
|
||||
hass: HomeAssistant,
|
||||
mock_init_component,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test subentry creation when model download fails."""
|
||||
|
||||
async def delayed_pull(self, model: str) -> None:
|
||||
"""Simulate a delayed model download."""
|
||||
await asyncio.sleep(0) # yield
|
||||
|
||||
raise RuntimeError("Download failed")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"ollama.AsyncClient.list",
|
||||
return_value={"models": [{"model": TEST_MODEL}]},
|
||||
),
|
||||
patch("ollama.AsyncClient.pull", delayed_pull),
|
||||
):
|
||||
new_flow = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert new_flow["type"] is FlowResultType.FORM
|
||||
assert new_flow["step_id"] == "set_options"
|
||||
|
||||
# Configure with a model that needs downloading but will fail
|
||||
result = await hass.config_entries.subentries.async_configure(
|
||||
new_flow["flow_id"],
|
||||
{
|
||||
ollama.CONF_MODEL: "llama3.2:latest",
|
||||
CONF_NAME: "New Test Conversation",
|
||||
ollama.CONF_PROMPT: "new test prompt",
|
||||
ollama.CONF_MAX_HISTORY: 50,
|
||||
ollama.CONF_NUM_CTX: 16384,
|
||||
ollama.CONF_THINK: False,
|
||||
},
|
||||
)
|
||||
|
||||
# Should show progress flow result for download
|
||||
assert result["type"] is FlowResultType.SHOW_PROGRESS
|
||||
assert result["step_id"] == "download"
|
||||
assert result["progress_action"] == "download"
|
||||
|
||||
# Wait for download task to complete (with error)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Submit the progress flow - should get failure
|
||||
result = await hass.config_entries.subentries.async_configure(
|
||||
new_flow["flow_id"], {}
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "download_failed"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("side_effect", "error"),
|
||||
[
|
||||
@ -262,40 +332,132 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
|
||||
assert result2["errors"] == {"base": error}
|
||||
|
||||
|
||||
async def test_download_error(hass: HomeAssistant) -> None:
|
||||
"""Test we handle errors while downloading a model."""
|
||||
async def test_form_invalid_url(hass: HomeAssistant) -> None:
|
||||
"""Test we handle invalid URL."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
async def _delayed_runtime_error(*args, **kwargs):
|
||||
await asyncio.sleep(0)
|
||||
raise RuntimeError
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {ollama.CONF_URL: "not-a-valid-url"}
|
||||
)
|
||||
|
||||
assert result2["type"] is FlowResultType.FORM
|
||||
assert result2["errors"] == {"base": "invalid_url"}
|
||||
|
||||
|
||||
async def test_subentry_connection_error(
|
||||
hass: HomeAssistant,
|
||||
mock_init_component,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test subentry creation when connection to Ollama server fails."""
|
||||
with patch(
|
||||
"ollama.AsyncClient.list",
|
||||
side_effect=ConnectError("Connection failed"),
|
||||
):
|
||||
new_flow = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert new_flow["type"] is FlowResultType.ABORT
|
||||
assert new_flow["reason"] == "cannot_connect"
|
||||
|
||||
|
||||
async def test_subentry_model_check_exception(
|
||||
hass: HomeAssistant,
|
||||
mock_init_component,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test subentry creation when checking model availability throws exception."""
|
||||
with patch(
|
||||
"ollama.AsyncClient.list",
|
||||
side_effect=[
|
||||
{"models": [{"model": TEST_MODEL}]}, # First call succeeds
|
||||
RuntimeError("Failed to check models"), # Second call fails
|
||||
],
|
||||
):
|
||||
new_flow = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert new_flow["type"] is FlowResultType.FORM
|
||||
assert new_flow["step_id"] == "set_options"
|
||||
|
||||
# Configure with a model, should fail when checking availability
|
||||
result = await hass.config_entries.subentries.async_configure(
|
||||
new_flow["flow_id"],
|
||||
{
|
||||
ollama.CONF_MODEL: "new_model:latest",
|
||||
CONF_NAME: "Test Conversation",
|
||||
ollama.CONF_PROMPT: "test prompt",
|
||||
ollama.CONF_MAX_HISTORY: 50,
|
||||
ollama.CONF_NUM_CTX: 16384,
|
||||
ollama.CONF_THINK: False,
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "cannot_connect"
|
||||
|
||||
|
||||
async def test_subentry_reconfigure_with_download(
|
||||
hass: HomeAssistant,
|
||||
mock_init_component,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test reconfiguring subentry when model needs to be downloaded."""
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
|
||||
async def delayed_pull(self, model: str) -> None:
|
||||
"""Simulate a delayed model download."""
|
||||
assert model == "llama3.2:latest"
|
||||
await asyncio.sleep(0) # yield the event loop
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
|
||||
return_value={},
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull",
|
||||
_delayed_runtime_error,
|
||||
"ollama.AsyncClient.list",
|
||||
return_value={"models": [{"model": TEST_MODEL}]},
|
||||
),
|
||||
patch("ollama.AsyncClient.pull", delayed_pull),
|
||||
):
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
|
||||
reconfigure_flow = await mock_config_entry.start_subentry_reconfigure_flow(
|
||||
hass, subentry.subentry_id
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2["type"] is FlowResultType.FORM
|
||||
result3 = await hass.config_entries.flow.async_configure(
|
||||
result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
|
||||
assert reconfigure_flow["type"] is FlowResultType.FORM
|
||||
assert reconfigure_flow["step_id"] == "set_options"
|
||||
|
||||
# Reconfigure with a model that needs downloading
|
||||
result = await hass.config_entries.subentries.async_configure(
|
||||
reconfigure_flow["flow_id"],
|
||||
{
|
||||
ollama.CONF_MODEL: "llama3.2:latest",
|
||||
ollama.CONF_PROMPT: "updated prompt",
|
||||
ollama.CONF_MAX_HISTORY: 75,
|
||||
ollama.CONF_NUM_CTX: 8192,
|
||||
ollama.CONF_THINK: True,
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.SHOW_PROGRESS
|
||||
assert result["step_id"] == "download"
|
||||
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result3["type"] is FlowResultType.SHOW_PROGRESS
|
||||
result4 = await hass.config_entries.flow.async_configure(result3["flow_id"])
|
||||
await hass.async_block_till_done()
|
||||
# Finish download
|
||||
result = await hass.config_entries.subentries.async_configure(
|
||||
reconfigure_flow["flow_id"], {}
|
||||
)
|
||||
|
||||
assert result4["type"] is FlowResultType.ABORT
|
||||
assert result4["reason"] == "download_failed"
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "reconfigure_successful"
|
||||
assert subentry.data == {
|
||||
ollama.CONF_MODEL: "llama3.2:latest",
|
||||
ollama.CONF_PROMPT: "updated prompt",
|
||||
ollama.CONF_MAX_HISTORY: 75.0,
|
||||
ollama.CONF_NUM_CTX: 8192.0,
|
||||
ollama.CONF_THINK: True,
|
||||
}
|
||||
|
@ -15,7 +15,12 @@ from homeassistant.components.conversation import trace
|
||||
from homeassistant.const import ATTR_SUPPORTED_FEATURES, CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import intent, llm
|
||||
from homeassistant.helpers import (
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
intent,
|
||||
llm,
|
||||
)
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
@ -68,7 +73,7 @@ async def test_chat(
|
||||
args = mock_chat.call_args.kwargs
|
||||
prompt = args["messages"][0]["content"]
|
||||
|
||||
assert args["model"] == "test model"
|
||||
assert args["model"] == "test_model:latest"
|
||||
assert args["messages"] == [
|
||||
Message(role="system", content=prompt),
|
||||
Message(role="user", content="test message"),
|
||||
@ -128,7 +133,7 @@ async def test_chat_stream(
|
||||
args = mock_chat.call_args.kwargs
|
||||
prompt = args["messages"][0]["content"]
|
||||
|
||||
assert args["model"] == "test model"
|
||||
assert args["model"] == "test_model:latest"
|
||||
assert args["messages"] == [
|
||||
Message(role="system", content=prompt),
|
||||
Message(role="user", content="test message"),
|
||||
@ -158,6 +163,7 @@ async def test_template_variables(
|
||||
"The user name is {{ user_name }}. "
|
||||
"The user id is {{ llm_context.context.user_id }}."
|
||||
),
|
||||
ollama.CONF_MODEL: "test_model:latest",
|
||||
},
|
||||
)
|
||||
with (
|
||||
@ -524,7 +530,9 @@ async def test_message_history_unlimited(
|
||||
):
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
hass.config_entries.async_update_subentry(
|
||||
mock_config_entry, subentry, data={ollama.CONF_MAX_HISTORY: 0}
|
||||
mock_config_entry,
|
||||
subentry,
|
||||
data={**subentry.data, ollama.CONF_MAX_HISTORY: 0},
|
||||
)
|
||||
for i in range(100):
|
||||
result = await conversation.async_converse(
|
||||
@ -573,6 +581,7 @@ async def test_template_error(
|
||||
mock_config_entry,
|
||||
subentry,
|
||||
data={
|
||||
**subentry.data,
|
||||
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
|
||||
},
|
||||
)
|
||||
@ -593,6 +602,8 @@ async def test_conversation_agent(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test OllamaConversationEntity."""
|
||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||
@ -604,6 +615,24 @@ async def test_conversation_agent(
|
||||
assert state
|
||||
assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0
|
||||
|
||||
entity_entry = entity_registry.async_get("conversation.ollama_conversation")
|
||||
assert entity_entry
|
||||
subentry = mock_config_entry.subentries.get(entity_entry.unique_id)
|
||||
assert subentry
|
||||
assert entity_entry.original_name == subentry.title
|
||||
|
||||
device_entry = device_registry.async_get(entity_entry.device_id)
|
||||
assert device_entry
|
||||
|
||||
assert device_entry.identifiers == {(ollama.DOMAIN, subentry.subentry_id)}
|
||||
assert device_entry.name == subentry.title
|
||||
assert device_entry.manufacturer == "Ollama"
|
||||
assert device_entry.entry_type == dr.DeviceEntryType.SERVICE
|
||||
|
||||
model, _, version = subentry.data[ollama.CONF_MODEL].partition(":")
|
||||
assert device_entry.model == model
|
||||
assert device_entry.sw_version == version
|
||||
|
||||
|
||||
async def test_conversation_agent_with_assist(
|
||||
hass: HomeAssistant,
|
||||
@ -679,6 +708,7 @@ async def test_reasoning_filter(
|
||||
mock_config_entry,
|
||||
subentry,
|
||||
data={
|
||||
**subentry.data,
|
||||
ollama.CONF_THINK: think,
|
||||
},
|
||||
)
|
||||
|
@ -9,13 +9,26 @@ from homeassistant.components import ollama
|
||||
from homeassistant.components.ollama.const import DOMAIN
|
||||
from homeassistant.config_entries import ConfigSubentryData
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er, llm
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import TEST_OPTIONS, TEST_USER_DATA
|
||||
from . import TEST_OPTIONS
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
V1_TEST_USER_DATA = {
|
||||
ollama.CONF_URL: "http://localhost:11434",
|
||||
ollama.CONF_MODEL: "test_model:latest",
|
||||
}
|
||||
|
||||
V1_TEST_OPTIONS = {
|
||||
ollama.CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||
ollama.CONF_MAX_HISTORY: 2,
|
||||
}
|
||||
|
||||
V21_TEST_USER_DATA = V1_TEST_USER_DATA
|
||||
V21_TEST_OPTIONS = V1_TEST_OPTIONS
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("side_effect", "error"),
|
||||
@ -41,17 +54,17 @@ async def test_init_error(
|
||||
assert error in caplog.text
|
||||
|
||||
|
||||
async def test_migration_from_v1_to_v2(
|
||||
async def test_migration_from_v1(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test migration from version 1 to version 2."""
|
||||
"""Test migration from version 1."""
|
||||
# Create a v1 config entry with conversation options and an entity
|
||||
mock_config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data=TEST_USER_DATA,
|
||||
options=TEST_OPTIONS,
|
||||
data=V1_TEST_USER_DATA,
|
||||
options=V1_TEST_OPTIONS,
|
||||
version=1,
|
||||
title="llama-3.2-8b",
|
||||
)
|
||||
@ -81,9 +94,10 @@ async def test_migration_from_v1_to_v2(
|
||||
):
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
|
||||
assert mock_config_entry.version == 2
|
||||
assert mock_config_entry.minor_version == 2
|
||||
assert mock_config_entry.data == TEST_USER_DATA
|
||||
assert mock_config_entry.version == 3
|
||||
assert mock_config_entry.minor_version == 1
|
||||
# After migration, parent entry should only have URL
|
||||
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
|
||||
assert mock_config_entry.options == {}
|
||||
|
||||
assert len(mock_config_entry.subentries) == 1
|
||||
@ -92,7 +106,9 @@ async def test_migration_from_v1_to_v2(
|
||||
assert subentry.unique_id is None
|
||||
assert subentry.title == "llama-3.2-8b"
|
||||
assert subentry.subentry_type == "conversation"
|
||||
assert subentry.data == TEST_OPTIONS
|
||||
# Subentry should now include the model from the original options
|
||||
expected_subentry_data = TEST_OPTIONS.copy()
|
||||
assert subentry.data == expected_subentry_data
|
||||
|
||||
migrated_entity = entity_registry.async_get(entity.entity_id)
|
||||
assert migrated_entity is not None
|
||||
@ -117,17 +133,17 @@ async def test_migration_from_v1_to_v2(
|
||||
}
|
||||
|
||||
|
||||
async def test_migration_from_v1_to_v2_with_multiple_urls(
|
||||
async def test_migration_from_v1_with_multiple_urls(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test migration from version 1 to version 2 with different URLs."""
|
||||
"""Test migration from version 1 with different URLs."""
|
||||
# Create two v1 config entries with different URLs
|
||||
mock_config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
|
||||
options=TEST_OPTIONS,
|
||||
options=V1_TEST_OPTIONS,
|
||||
version=1,
|
||||
title="Ollama 1",
|
||||
)
|
||||
@ -135,7 +151,7 @@ async def test_migration_from_v1_to_v2_with_multiple_urls(
|
||||
mock_config_entry_2 = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={"url": "http://localhost:11435", "model": "llama3.2:latest"},
|
||||
options=TEST_OPTIONS,
|
||||
options=V1_TEST_OPTIONS,
|
||||
version=1,
|
||||
title="Ollama 2",
|
||||
)
|
||||
@ -187,13 +203,16 @@ async def test_migration_from_v1_to_v2_with_multiple_urls(
|
||||
assert len(entries) == 2
|
||||
|
||||
for idx, entry in enumerate(entries):
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 2
|
||||
assert entry.version == 3
|
||||
assert entry.minor_version == 1
|
||||
assert not entry.options
|
||||
assert len(entry.subentries) == 1
|
||||
subentry = list(entry.subentries.values())[0]
|
||||
assert subentry.subentry_type == "conversation"
|
||||
assert subentry.data == TEST_OPTIONS
|
||||
# Subentry should include the model along with the original options
|
||||
expected_subentry_data = TEST_OPTIONS.copy()
|
||||
expected_subentry_data["model"] = "llama3.2:latest"
|
||||
assert subentry.data == expected_subentry_data
|
||||
assert subentry.title == f"Ollama {idx + 1}"
|
||||
|
||||
dev = device_registry.async_get_device(
|
||||
@ -204,17 +223,17 @@ async def test_migration_from_v1_to_v2_with_multiple_urls(
|
||||
assert dev.config_entries_subentries == {entry.entry_id: {subentry.subentry_id}}
|
||||
|
||||
|
||||
async def test_migration_from_v1_to_v2_with_same_urls(
|
||||
async def test_migration_from_v1_with_same_urls(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test migration from version 1 to version 2 with same URLs consolidates entries."""
|
||||
"""Test migration from version 1 with same URLs consolidates entries."""
|
||||
# Create two v1 config entries with the same URL
|
||||
mock_config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
|
||||
options=TEST_OPTIONS,
|
||||
options=V1_TEST_OPTIONS,
|
||||
version=1,
|
||||
title="Ollama",
|
||||
)
|
||||
@ -222,7 +241,7 @@ async def test_migration_from_v1_to_v2_with_same_urls(
|
||||
mock_config_entry_2 = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, # Same URL
|
||||
options=TEST_OPTIONS,
|
||||
options=V1_TEST_OPTIONS,
|
||||
version=1,
|
||||
title="Ollama 2",
|
||||
)
|
||||
@ -275,8 +294,8 @@ async def test_migration_from_v1_to_v2_with_same_urls(
|
||||
assert len(entries) == 1
|
||||
|
||||
entry = entries[0]
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 2
|
||||
assert entry.version == 3
|
||||
assert entry.minor_version == 1
|
||||
assert not entry.options
|
||||
assert len(entry.subentries) == 2 # Two subentries from the two original entries
|
||||
|
||||
@ -288,7 +307,10 @@ async def test_migration_from_v1_to_v2_with_same_urls(
|
||||
|
||||
for subentry in subentries:
|
||||
assert subentry.subentry_type == "conversation"
|
||||
assert subentry.data == TEST_OPTIONS
|
||||
# Subentry should include the model along with the original options
|
||||
expected_subentry_data = TEST_OPTIONS.copy()
|
||||
expected_subentry_data["model"] = "llama3.2:latest"
|
||||
assert subentry.data == expected_subentry_data
|
||||
|
||||
# Check devices were migrated correctly
|
||||
dev = device_registry.async_get_device(
|
||||
@ -301,12 +323,12 @@ async def test_migration_from_v1_to_v2_with_same_urls(
|
||||
}
|
||||
|
||||
|
||||
async def test_migration_from_v2_1_to_v2_2(
|
||||
async def test_migration_from_v2_1(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test migration from version 2.1 to version 2.2.
|
||||
"""Test migration from version 2.1.
|
||||
|
||||
This tests we clean up the broken migration in Home Assistant Core
|
||||
2025.7.0b0-2025.7.0b1:
|
||||
@ -315,20 +337,20 @@ async def test_migration_from_v2_1_to_v2_2(
|
||||
# Create a v2.1 config entry with 2 subentries, devices and entities
|
||||
mock_config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data=TEST_USER_DATA,
|
||||
data=V21_TEST_USER_DATA,
|
||||
entry_id="mock_entry_id",
|
||||
version=2,
|
||||
minor_version=1,
|
||||
subentries_data=[
|
||||
ConfigSubentryData(
|
||||
data=TEST_OPTIONS,
|
||||
data=V21_TEST_OPTIONS,
|
||||
subentry_id="mock_id_1",
|
||||
subentry_type="conversation",
|
||||
title="Ollama",
|
||||
unique_id=None,
|
||||
),
|
||||
ConfigSubentryData(
|
||||
data=TEST_OPTIONS,
|
||||
data=V21_TEST_OPTIONS,
|
||||
subentry_id="mock_id_2",
|
||||
subentry_type="conversation",
|
||||
title="Ollama 2",
|
||||
@ -392,8 +414,8 @@ async def test_migration_from_v2_1_to_v2_2(
|
||||
entries = hass.config_entries.async_entries(DOMAIN)
|
||||
assert len(entries) == 1
|
||||
entry = entries[0]
|
||||
assert entry.version == 2
|
||||
assert entry.minor_version == 2
|
||||
assert entry.version == 3
|
||||
assert entry.minor_version == 1
|
||||
assert not entry.options
|
||||
assert entry.title == "Ollama"
|
||||
assert len(entry.subentries) == 2
|
||||
@ -405,6 +427,7 @@ async def test_migration_from_v2_1_to_v2_2(
|
||||
assert len(conversation_subentries) == 2
|
||||
for subentry in conversation_subentries:
|
||||
assert subentry.subentry_type == "conversation"
|
||||
# Since TEST_USER_DATA no longer has a model, subentry data should be TEST_OPTIONS
|
||||
assert subentry.data == TEST_OPTIONS
|
||||
assert "Ollama" in subentry.title
|
||||
|
||||
@ -450,3 +473,45 @@ async def test_migration_from_v2_1_to_v2_2(
|
||||
assert device.config_entries_subentries == {
|
||||
mock_config_entry.entry_id: {subentry.subentry_id}
|
||||
}
|
||||
|
||||
|
||||
async def test_migration_from_v2_2(hass: HomeAssistant) -> None:
|
||||
"""Test migration from version 2.2."""
|
||||
subentry_data = ConfigSubentryData(
|
||||
data=V21_TEST_USER_DATA,
|
||||
subentry_type="conversation",
|
||||
title="Test Conversation",
|
||||
unique_id=None,
|
||||
)
|
||||
|
||||
mock_config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
ollama.CONF_URL: "http://localhost:11434",
|
||||
ollama.CONF_MODEL: "test_model:latest", # Model still in main data
|
||||
},
|
||||
version=2,
|
||||
minor_version=2,
|
||||
subentries_data=[subentry_data],
|
||||
)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.ollama.async_setup_entry",
|
||||
return_value=True,
|
||||
):
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
|
||||
# Check migration to v3.1
|
||||
assert mock_config_entry.version == 3
|
||||
assert mock_config_entry.minor_version == 1
|
||||
|
||||
# Check that model was moved from main data to subentry
|
||||
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
|
||||
assert len(mock_config_entry.subentries) == 1
|
||||
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
assert subentry.data == {
|
||||
**V21_TEST_USER_DATA,
|
||||
ollama.CONF_MODEL: "test_model:latest",
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user