mirror of
https://github.com/home-assistant/core.git
synced 2025-07-09 14:27:07 +00:00
Ollama: Migrate pick model to subentry (#147944)
This commit is contained in:
parent
943fb9948b
commit
f50ef79c72
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from types import MappingProxyType
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import ollama
|
import ollama
|
||||||
@ -100,8 +101,12 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
|
|||||||
|
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
use_existing = False
|
use_existing = False
|
||||||
|
# Create subentry with model from entry.data and options from entry.options
|
||||||
|
subentry_data = entry.options.copy()
|
||||||
|
subentry_data[CONF_MODEL] = entry.data[CONF_MODEL]
|
||||||
|
|
||||||
subentry = ConfigSubentry(
|
subentry = ConfigSubentry(
|
||||||
data=entry.options,
|
data=MappingProxyType(subentry_data),
|
||||||
subentry_type="conversation",
|
subentry_type="conversation",
|
||||||
title=entry.title,
|
title=entry.title,
|
||||||
unique_id=None,
|
unique_id=None,
|
||||||
@ -154,9 +159,11 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
|
|||||||
hass.config_entries.async_update_entry(
|
hass.config_entries.async_update_entry(
|
||||||
entry,
|
entry,
|
||||||
title=DEFAULT_NAME,
|
title=DEFAULT_NAME,
|
||||||
|
# Update parent entry to only keep URL, remove model
|
||||||
|
data={CONF_URL: entry.data[CONF_URL]},
|
||||||
options={},
|
options={},
|
||||||
version=2,
|
version=3,
|
||||||
minor_version=2,
|
minor_version=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -164,7 +171,7 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OllamaConfigEntry) ->
|
|||||||
"""Migrate entry."""
|
"""Migrate entry."""
|
||||||
_LOGGER.debug("Migrating from version %s:%s", entry.version, entry.minor_version)
|
_LOGGER.debug("Migrating from version %s:%s", entry.version, entry.minor_version)
|
||||||
|
|
||||||
if entry.version > 2:
|
if entry.version > 3:
|
||||||
# This means the user has downgraded from a future version
|
# This means the user has downgraded from a future version
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -182,6 +189,25 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OllamaConfigEntry) ->
|
|||||||
|
|
||||||
hass.config_entries.async_update_entry(entry, minor_version=2)
|
hass.config_entries.async_update_entry(entry, minor_version=2)
|
||||||
|
|
||||||
|
if entry.version == 2 and entry.minor_version == 2:
|
||||||
|
# Update subentries to include the model
|
||||||
|
for subentry in entry.subentries.values():
|
||||||
|
if subentry.subentry_type == "conversation":
|
||||||
|
updated_data = dict(subentry.data)
|
||||||
|
updated_data[CONF_MODEL] = entry.data[CONF_MODEL]
|
||||||
|
|
||||||
|
hass.config_entries.async_update_subentry(
|
||||||
|
entry, subentry, data=MappingProxyType(updated_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update main entry to remove model and bump version
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
entry,
|
||||||
|
data={CONF_URL: entry.data[CONF_URL]},
|
||||||
|
version=3,
|
||||||
|
minor_version=1,
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Migration to version %s:%s successful", entry.version, entry.minor_version
|
"Migration to version %s:%s successful", entry.version, entry.minor_version
|
||||||
)
|
)
|
||||||
|
@ -22,7 +22,7 @@ from homeassistant.config_entries import (
|
|||||||
)
|
)
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL
|
from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import llm
|
from homeassistant.helpers import config_validation as cv, llm
|
||||||
from homeassistant.helpers.selector import (
|
from homeassistant.helpers.selector import (
|
||||||
BooleanSelector,
|
BooleanSelector,
|
||||||
NumberSelector,
|
NumberSelector,
|
||||||
@ -38,6 +38,7 @@ from homeassistant.helpers.selector import (
|
|||||||
)
|
)
|
||||||
from homeassistant.util.ssl import get_default_context
|
from homeassistant.util.ssl import get_default_context
|
||||||
|
|
||||||
|
from . import OllamaConfigEntry
|
||||||
from .const import (
|
from .const import (
|
||||||
CONF_KEEP_ALIVE,
|
CONF_KEEP_ALIVE,
|
||||||
CONF_MAX_HISTORY,
|
CONF_MAX_HISTORY,
|
||||||
@ -72,43 +73,43 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
|||||||
class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
"""Handle a config flow for Ollama."""
|
"""Handle a config flow for Ollama."""
|
||||||
|
|
||||||
VERSION = 2
|
VERSION = 3
|
||||||
MINOR_VERSION = 2
|
MINOR_VERSION = 1
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize config flow."""
|
"""Initialize config flow."""
|
||||||
self.url: str | None = None
|
self.url: str | None = None
|
||||||
self.model: str | None = None
|
|
||||||
self.client: ollama.AsyncClient | None = None
|
|
||||||
self.download_task: asyncio.Task | None = None
|
|
||||||
|
|
||||||
async def async_step_user(
|
async def async_step_user(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> ConfigFlowResult:
|
) -> ConfigFlowResult:
|
||||||
"""Handle the initial step."""
|
"""Handle the initial step."""
|
||||||
user_input = user_input or {}
|
if user_input is None:
|
||||||
self.url = user_input.get(CONF_URL, self.url)
|
|
||||||
self.model = user_input.get(CONF_MODEL, self.model)
|
|
||||||
|
|
||||||
if self.url is None:
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, last_step=False
|
step_id="user", data_schema=STEP_USER_DATA_SCHEMA
|
||||||
)
|
)
|
||||||
|
|
||||||
errors = {}
|
errors = {}
|
||||||
|
url = user_input[CONF_URL]
|
||||||
|
|
||||||
self._async_abort_entries_match({CONF_URL: self.url})
|
self._async_abort_entries_match({CONF_URL: url})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.client = ollama.AsyncClient(
|
url = cv.url(url)
|
||||||
host=self.url, verify=get_default_context()
|
except vol.Invalid:
|
||||||
|
errors["base"] = "invalid_url"
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="user",
|
||||||
|
data_schema=self.add_suggested_values_to_schema(
|
||||||
|
STEP_USER_DATA_SCHEMA, user_input
|
||||||
|
),
|
||||||
|
errors=errors,
|
||||||
)
|
)
|
||||||
async with asyncio.timeout(DEFAULT_TIMEOUT):
|
|
||||||
response = await self.client.list()
|
|
||||||
|
|
||||||
downloaded_models: set[str] = {
|
try:
|
||||||
model_info["model"] for model_info in response.get("models", [])
|
client = ollama.AsyncClient(host=url, verify=get_default_context())
|
||||||
}
|
async with asyncio.timeout(DEFAULT_TIMEOUT):
|
||||||
|
await client.list()
|
||||||
except (TimeoutError, httpx.ConnectError):
|
except (TimeoutError, httpx.ConnectError):
|
||||||
errors["base"] = "cannot_connect"
|
errors["base"] = "cannot_connect"
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -117,10 +118,69 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
|
step_id="user",
|
||||||
|
data_schema=self.add_suggested_values_to_schema(
|
||||||
|
STEP_USER_DATA_SCHEMA, user_input
|
||||||
|
),
|
||||||
|
errors=errors,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.model is None:
|
return self.async_create_entry(
|
||||||
|
title=url,
|
||||||
|
data={CONF_URL: url},
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@callback
|
||||||
|
def async_get_supported_subentry_types(
|
||||||
|
cls, config_entry: ConfigEntry
|
||||||
|
) -> dict[str, type[ConfigSubentryFlow]]:
|
||||||
|
"""Return subentries supported by this integration."""
|
||||||
|
return {"conversation": ConversationSubentryFlowHandler}
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||||
|
"""Flow for managing conversation subentries."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the subentry flow."""
|
||||||
|
super().__init__()
|
||||||
|
self._name: str | None = None
|
||||||
|
self._model: str | None = None
|
||||||
|
self.download_task: asyncio.Task | None = None
|
||||||
|
self._config_data: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _is_new(self) -> bool:
|
||||||
|
"""Return if this is a new subentry."""
|
||||||
|
return self.source == "user"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _client(self) -> ollama.AsyncClient:
|
||||||
|
"""Return the Ollama client."""
|
||||||
|
entry: OllamaConfigEntry = self._get_entry()
|
||||||
|
return entry.runtime_data
|
||||||
|
|
||||||
|
async def async_step_set_options(
|
||||||
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
) -> SubentryFlowResult:
|
||||||
|
"""Handle model selection and configuration step."""
|
||||||
|
if self._get_entry().state != ConfigEntryState.LOADED:
|
||||||
|
return self.async_abort(reason="entry_not_loaded")
|
||||||
|
|
||||||
|
if user_input is None:
|
||||||
|
# Get available models from Ollama server
|
||||||
|
try:
|
||||||
|
async with asyncio.timeout(DEFAULT_TIMEOUT):
|
||||||
|
response = await self._client.list()
|
||||||
|
|
||||||
|
downloaded_models: set[str] = {
|
||||||
|
model_info["model"] for model_info in response.get("models", [])
|
||||||
|
}
|
||||||
|
except (TimeoutError, httpx.ConnectError, httpx.HTTPError):
|
||||||
|
_LOGGER.exception("Failed to get models from Ollama server")
|
||||||
|
return self.async_abort(reason="cannot_connect")
|
||||||
|
|
||||||
# Show models that have been downloaded first, followed by all known
|
# Show models that have been downloaded first, followed by all known
|
||||||
# models (only latest tags).
|
# models (only latest tags).
|
||||||
models_to_list = [
|
models_to_list = [
|
||||||
@ -131,52 +191,69 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
for m in sorted(MODEL_NAMES)
|
for m in sorted(MODEL_NAMES)
|
||||||
if m not in downloaded_models
|
if m not in downloaded_models
|
||||||
]
|
]
|
||||||
model_step_schema = vol.Schema(
|
|
||||||
{
|
if self._is_new:
|
||||||
vol.Required(
|
options = {}
|
||||||
CONF_MODEL, description={"suggested_value": DEFAULT_MODEL}
|
else:
|
||||||
): SelectSelector(
|
options = self._get_reconfigure_subentry().data.copy()
|
||||||
SelectSelectorConfig(options=models_to_list, custom_value=True)
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="user",
|
step_id="set_options",
|
||||||
data_schema=model_step_schema,
|
data_schema=vol.Schema(
|
||||||
|
ollama_config_option_schema(
|
||||||
|
self.hass, self._is_new, options, models_to_list
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.model not in downloaded_models:
|
self._model = user_input[CONF_MODEL]
|
||||||
# Ollama server needs to download model first
|
if self._is_new:
|
||||||
return await self.async_step_download()
|
self._name = user_input.pop(CONF_NAME)
|
||||||
|
|
||||||
return self.async_create_entry(
|
# Check if model needs to be downloaded
|
||||||
title=self.url,
|
try:
|
||||||
data={CONF_URL: self.url, CONF_MODEL: self.model},
|
async with asyncio.timeout(DEFAULT_TIMEOUT):
|
||||||
subentries=[
|
response = await self._client.list()
|
||||||
{
|
|
||||||
"subentry_type": "conversation",
|
currently_downloaded_models: set[str] = {
|
||||||
"data": {},
|
model_info["model"] for model_info in response.get("models", [])
|
||||||
"title": _get_title(self.model),
|
}
|
||||||
"unique_id": None,
|
|
||||||
}
|
if self._model not in currently_downloaded_models:
|
||||||
],
|
# Store the user input to use after download
|
||||||
|
self._config_data = user_input
|
||||||
|
# Ollama server needs to download model first
|
||||||
|
return await self.async_step_download()
|
||||||
|
except Exception:
|
||||||
|
_LOGGER.exception("Failed to check model availability")
|
||||||
|
return self.async_abort(reason="cannot_connect")
|
||||||
|
|
||||||
|
# Model is already downloaded, create/update the entry
|
||||||
|
if self._is_new:
|
||||||
|
return self.async_create_entry(
|
||||||
|
title=self._name,
|
||||||
|
data=user_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.async_update_and_abort(
|
||||||
|
self._get_entry(),
|
||||||
|
self._get_reconfigure_subentry(),
|
||||||
|
data=user_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_step_download(
|
async def async_step_download(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> ConfigFlowResult:
|
) -> SubentryFlowResult:
|
||||||
"""Step to wait for Ollama server to download a model."""
|
"""Step to wait for Ollama server to download a model."""
|
||||||
assert self.model is not None
|
assert self._model is not None
|
||||||
assert self.client is not None
|
|
||||||
|
|
||||||
if self.download_task is None:
|
if self.download_task is None:
|
||||||
# Tell Ollama server to pull the model.
|
# Tell Ollama server to pull the model.
|
||||||
# The task will block until the model and metadata are fully
|
# The task will block until the model and metadata are fully
|
||||||
# downloaded.
|
# downloaded.
|
||||||
self.download_task = self.hass.async_create_background_task(
|
self.download_task = self.hass.async_create_background_task(
|
||||||
self.client.pull(self.model),
|
self._client.pull(self._model),
|
||||||
f"Downloading {self.model}",
|
f"Downloading {self._model}",
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.download_task.done():
|
if self.download_task.done():
|
||||||
@ -192,80 +269,28 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
progress_task=self.download_task,
|
progress_task=self.download_task,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_step_finish(
|
|
||||||
self, user_input: dict[str, Any] | None = None
|
|
||||||
) -> ConfigFlowResult:
|
|
||||||
"""Step after model downloading has succeeded."""
|
|
||||||
assert self.url is not None
|
|
||||||
assert self.model is not None
|
|
||||||
|
|
||||||
return self.async_create_entry(
|
|
||||||
title=_get_title(self.model),
|
|
||||||
data={CONF_URL: self.url, CONF_MODEL: self.model},
|
|
||||||
subentries=[
|
|
||||||
{
|
|
||||||
"subentry_type": "conversation",
|
|
||||||
"data": {},
|
|
||||||
"title": _get_title(self.model),
|
|
||||||
"unique_id": None,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
async def async_step_failed(
|
async def async_step_failed(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> ConfigFlowResult:
|
) -> SubentryFlowResult:
|
||||||
"""Step after model downloading has failed."""
|
"""Step after model downloading has failed."""
|
||||||
return self.async_abort(reason="download_failed")
|
return self.async_abort(reason="download_failed")
|
||||||
|
|
||||||
@classmethod
|
async def async_step_finish(
|
||||||
@callback
|
|
||||||
def async_get_supported_subentry_types(
|
|
||||||
cls, config_entry: ConfigEntry
|
|
||||||
) -> dict[str, type[ConfigSubentryFlow]]:
|
|
||||||
"""Return subentries supported by this integration."""
|
|
||||||
return {"conversation": ConversationSubentryFlowHandler}
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|
||||||
"""Flow for managing conversation subentries."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _is_new(self) -> bool:
|
|
||||||
"""Return if this is a new subentry."""
|
|
||||||
return self.source == "user"
|
|
||||||
|
|
||||||
async def async_step_set_options(
|
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> SubentryFlowResult:
|
) -> SubentryFlowResult:
|
||||||
"""Set conversation options."""
|
"""Step after model downloading has succeeded."""
|
||||||
# abort if entry is not loaded
|
assert self._config_data is not None
|
||||||
if self._get_entry().state != ConfigEntryState.LOADED:
|
|
||||||
return self.async_abort(reason="entry_not_loaded")
|
|
||||||
|
|
||||||
errors: dict[str, str] = {}
|
# Model download completed, create/update the entry with stored config
|
||||||
|
if self._is_new:
|
||||||
if user_input is None:
|
|
||||||
if self._is_new:
|
|
||||||
options = {}
|
|
||||||
else:
|
|
||||||
options = self._get_reconfigure_subentry().data.copy()
|
|
||||||
|
|
||||||
elif self._is_new:
|
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=user_input.pop(CONF_NAME),
|
title=self._name,
|
||||||
data=user_input,
|
data=self._config_data,
|
||||||
)
|
)
|
||||||
else:
|
return self.async_update_and_abort(
|
||||||
return self.async_update_and_abort(
|
self._get_entry(),
|
||||||
self._get_entry(),
|
self._get_reconfigure_subentry(),
|
||||||
self._get_reconfigure_subentry(),
|
data=self._config_data,
|
||||||
data=user_input,
|
|
||||||
)
|
|
||||||
|
|
||||||
schema = ollama_config_option_schema(self.hass, self._is_new, options)
|
|
||||||
return self.async_show_form(
|
|
||||||
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async_step_user = async_step_set_options
|
async_step_user = async_step_set_options
|
||||||
@ -273,19 +298,14 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
|
|
||||||
|
|
||||||
def ollama_config_option_schema(
|
def ollama_config_option_schema(
|
||||||
hass: HomeAssistant, is_new: bool, options: Mapping[str, Any]
|
hass: HomeAssistant,
|
||||||
|
is_new: bool,
|
||||||
|
options: Mapping[str, Any],
|
||||||
|
models_to_list: list[SelectOptionDict],
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Ollama options schema."""
|
"""Ollama options schema."""
|
||||||
hass_apis: list[SelectOptionDict] = [
|
|
||||||
SelectOptionDict(
|
|
||||||
label=api.name,
|
|
||||||
value=api.id,
|
|
||||||
)
|
|
||||||
for api in llm.async_get_apis(hass)
|
|
||||||
]
|
|
||||||
|
|
||||||
if is_new:
|
if is_new:
|
||||||
schema: dict[vol.Required | vol.Optional, Any] = {
|
schema: dict = {
|
||||||
vol.Required(CONF_NAME, default="Ollama Conversation"): str,
|
vol.Required(CONF_NAME, default="Ollama Conversation"): str,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
@ -293,6 +313,12 @@ def ollama_config_option_schema(
|
|||||||
|
|
||||||
schema.update(
|
schema.update(
|
||||||
{
|
{
|
||||||
|
vol.Required(
|
||||||
|
CONF_MODEL,
|
||||||
|
description={"suggested_value": options.get(CONF_MODEL, DEFAULT_MODEL)},
|
||||||
|
): SelectSelector(
|
||||||
|
SelectSelectorConfig(options=models_to_list, custom_value=True)
|
||||||
|
),
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
description={
|
description={
|
||||||
@ -304,7 +330,18 @@ def ollama_config_option_schema(
|
|||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_LLM_HASS_API,
|
CONF_LLM_HASS_API,
|
||||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||||
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
|
): SelectSelector(
|
||||||
|
SelectSelectorConfig(
|
||||||
|
options=[
|
||||||
|
SelectOptionDict(
|
||||||
|
label=api.name,
|
||||||
|
value=api.id,
|
||||||
|
)
|
||||||
|
for api in llm.async_get_apis(hass)
|
||||||
|
],
|
||||||
|
multiple=True,
|
||||||
|
)
|
||||||
|
),
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_NUM_CTX,
|
CONF_NUM_CTX,
|
||||||
description={
|
description={
|
||||||
@ -350,11 +387,3 @@ def ollama_config_option_schema(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
|
|
||||||
def _get_title(model: str) -> str:
|
|
||||||
"""Get title for config entry."""
|
|
||||||
if model.endswith(":latest"):
|
|
||||||
model = model.split(":", maxsplit=1)[0]
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
@ -166,11 +166,14 @@ class OllamaBaseLLMEntity(Entity):
|
|||||||
self.subentry = subentry
|
self.subentry = subentry
|
||||||
self._attr_name = subentry.title
|
self._attr_name = subentry.title
|
||||||
self._attr_unique_id = subentry.subentry_id
|
self._attr_unique_id = subentry.subentry_id
|
||||||
|
|
||||||
|
model, _, version = subentry.data[CONF_MODEL].partition(":")
|
||||||
self._attr_device_info = dr.DeviceInfo(
|
self._attr_device_info = dr.DeviceInfo(
|
||||||
identifiers={(DOMAIN, subentry.subentry_id)},
|
identifiers={(DOMAIN, subentry.subentry_id)},
|
||||||
name=subentry.title,
|
name=subentry.title,
|
||||||
manufacturer="Ollama",
|
manufacturer="Ollama",
|
||||||
model=entry.data[CONF_MODEL],
|
model=model,
|
||||||
|
sw_version=version or "latest",
|
||||||
entry_type=dr.DeviceEntryType.SERVICE,
|
entry_type=dr.DeviceEntryType.SERVICE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -3,24 +3,17 @@
|
|||||||
"step": {
|
"step": {
|
||||||
"user": {
|
"user": {
|
||||||
"data": {
|
"data": {
|
||||||
"url": "[%key:common::config_flow::data::url%]",
|
"url": "[%key:common::config_flow::data::url%]"
|
||||||
"model": "Model"
|
|
||||||
}
|
}
|
||||||
},
|
|
||||||
"download": {
|
|
||||||
"title": "Downloading model"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"abort": {
|
"abort": {
|
||||||
"download_failed": "Model downloading failed",
|
|
||||||
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]"
|
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]"
|
||||||
},
|
},
|
||||||
"error": {
|
"error": {
|
||||||
|
"invalid_url": "[%key:common::config_flow::error::invalid_host%]",
|
||||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||||
"unknown": "[%key:common::config_flow::error::unknown%]"
|
"unknown": "[%key:common::config_flow::error::unknown%]"
|
||||||
},
|
|
||||||
"progress": {
|
|
||||||
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"config_subentries": {
|
"config_subentries": {
|
||||||
@ -33,6 +26,7 @@
|
|||||||
"step": {
|
"step": {
|
||||||
"set_options": {
|
"set_options": {
|
||||||
"data": {
|
"data": {
|
||||||
|
"model": "Model",
|
||||||
"name": "[%key:common::config_flow::data::name%]",
|
"name": "[%key:common::config_flow::data::name%]",
|
||||||
"prompt": "Instructions",
|
"prompt": "Instructions",
|
||||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
|
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
|
||||||
@ -47,11 +41,19 @@
|
|||||||
"num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.",
|
"num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.",
|
||||||
"think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency."
|
"think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency."
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"download": {
|
||||||
|
"title": "Downloading model"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"abort": {
|
"abort": {
|
||||||
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
|
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
|
||||||
"entry_not_loaded": "Cannot add things while the configuration is disabled."
|
"entry_not_loaded": "Failed to add agent. The configuration is disabled.",
|
||||||
|
"download_failed": "Model downloading failed",
|
||||||
|
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
|
||||||
|
},
|
||||||
|
"progress": {
|
||||||
|
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,10 +5,10 @@ from homeassistant.helpers import llm
|
|||||||
|
|
||||||
TEST_USER_DATA = {
|
TEST_USER_DATA = {
|
||||||
ollama.CONF_URL: "http://localhost:11434",
|
ollama.CONF_URL: "http://localhost:11434",
|
||||||
ollama.CONF_MODEL: "test model",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_OPTIONS = {
|
TEST_OPTIONS = {
|
||||||
ollama.CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
ollama.CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||||
ollama.CONF_MAX_HISTORY: 2,
|
ollama.CONF_MAX_HISTORY: 2,
|
||||||
|
ollama.CONF_MODEL: "test_model:latest",
|
||||||
}
|
}
|
||||||
|
@ -30,10 +30,11 @@ def mock_config_entry(
|
|||||||
entry = MockConfigEntry(
|
entry = MockConfigEntry(
|
||||||
domain=ollama.DOMAIN,
|
domain=ollama.DOMAIN,
|
||||||
data=TEST_USER_DATA,
|
data=TEST_USER_DATA,
|
||||||
version=2,
|
version=3,
|
||||||
|
minor_version=1,
|
||||||
subentries_data=[
|
subentries_data=[
|
||||||
{
|
{
|
||||||
"data": mock_config_entry_options,
|
"data": {**TEST_OPTIONS, **mock_config_entry_options},
|
||||||
"subentry_type": "conversation",
|
"subentry_type": "conversation",
|
||||||
"title": "Ollama Conversation",
|
"title": "Ollama Conversation",
|
||||||
"unique_id": None,
|
"unique_id": None,
|
||||||
@ -49,10 +50,14 @@ def mock_config_entry_with_assist(
|
|||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||||
) -> MockConfigEntry:
|
) -> MockConfigEntry:
|
||||||
"""Mock a config entry with assist."""
|
"""Mock a config entry with assist."""
|
||||||
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
hass.config_entries.async_update_subentry(
|
hass.config_entries.async_update_subentry(
|
||||||
mock_config_entry,
|
mock_config_entry,
|
||||||
next(iter(mock_config_entry.subentries.values())),
|
subentry,
|
||||||
data={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
|
data={
|
||||||
|
**subentry.data,
|
||||||
|
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
return mock_config_entry
|
return mock_config_entry
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ import pytest
|
|||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.components import ollama
|
from homeassistant.components import ollama
|
||||||
|
from homeassistant.const import CONF_NAME
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.data_entry_flow import FlowResultType
|
from homeassistant.data_entry_flow import FlowResultType
|
||||||
|
|
||||||
@ -17,7 +18,7 @@ TEST_MODEL = "test_model:latest"
|
|||||||
|
|
||||||
|
|
||||||
async def test_form(hass: HomeAssistant) -> None:
|
async def test_form(hass: HomeAssistant) -> None:
|
||||||
"""Test flow when the model is already downloaded."""
|
"""Test flow when configuring URL only."""
|
||||||
# Pretend we already set up a config entry.
|
# Pretend we already set up a config entry.
|
||||||
hass.config.components.add(ollama.DOMAIN)
|
hass.config.components.add(ollama.DOMAIN)
|
||||||
MockConfigEntry(
|
MockConfigEntry(
|
||||||
@ -34,7 +35,6 @@ async def test_form(hass: HomeAssistant) -> None:
|
|||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
|
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
|
||||||
# test model is already "downloaded"
|
|
||||||
return_value={"models": [{"model": TEST_MODEL}]},
|
return_value={"models": [{"model": TEST_MODEL}]},
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
@ -42,24 +42,17 @@ async def test_form(hass: HomeAssistant) -> None:
|
|||||||
return_value=True,
|
return_value=True,
|
||||||
) as mock_setup_entry,
|
) as mock_setup_entry,
|
||||||
):
|
):
|
||||||
# Step 1: URL
|
|
||||||
result2 = await hass.config_entries.flow.async_configure(
|
result2 = await hass.config_entries.flow.async_configure(
|
||||||
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
|
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
# Step 2: model
|
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||||
assert result2["type"] is FlowResultType.FORM
|
assert result2["data"] == {
|
||||||
result3 = await hass.config_entries.flow.async_configure(
|
|
||||||
result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
|
|
||||||
)
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
assert result3["type"] is FlowResultType.CREATE_ENTRY
|
|
||||||
assert result3["data"] == {
|
|
||||||
ollama.CONF_URL: "http://localhost:11434",
|
ollama.CONF_URL: "http://localhost:11434",
|
||||||
ollama.CONF_MODEL: TEST_MODEL,
|
|
||||||
}
|
}
|
||||||
|
# No subentries created by default
|
||||||
|
assert len(result2.get("subentries", [])) == 0
|
||||||
assert len(mock_setup_entry.mock_calls) == 1
|
assert len(mock_setup_entry.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
@ -94,98 +87,6 @@ async def test_duplicate_entry(hass: HomeAssistant) -> None:
|
|||||||
assert result["reason"] == "already_configured"
|
assert result["reason"] == "already_configured"
|
||||||
|
|
||||||
|
|
||||||
async def test_form_need_download(hass: HomeAssistant) -> None:
|
|
||||||
"""Test flow when a model needs to be downloaded."""
|
|
||||||
# Pretend we already set up a config entry.
|
|
||||||
hass.config.components.add(ollama.DOMAIN)
|
|
||||||
MockConfigEntry(
|
|
||||||
domain=ollama.DOMAIN,
|
|
||||||
state=config_entries.ConfigEntryState.LOADED,
|
|
||||||
).add_to_hass(hass)
|
|
||||||
|
|
||||||
result = await hass.config_entries.flow.async_init(
|
|
||||||
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
|
|
||||||
)
|
|
||||||
assert result["type"] is FlowResultType.FORM
|
|
||||||
assert result["errors"] is None
|
|
||||||
|
|
||||||
pull_ready = asyncio.Event()
|
|
||||||
pull_called = asyncio.Event()
|
|
||||||
pull_model: str | None = None
|
|
||||||
|
|
||||||
async def pull(self, model: str, *args, **kwargs) -> None:
|
|
||||||
nonlocal pull_model
|
|
||||||
|
|
||||||
async with asyncio.timeout(1):
|
|
||||||
await pull_ready.wait()
|
|
||||||
|
|
||||||
pull_model = model
|
|
||||||
pull_called.set()
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
|
|
||||||
# No models are downloaded
|
|
||||||
return_value={},
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull",
|
|
||||||
pull,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.ollama.async_setup_entry",
|
|
||||||
return_value=True,
|
|
||||||
) as mock_setup_entry,
|
|
||||||
):
|
|
||||||
# Step 1: URL
|
|
||||||
result2 = await hass.config_entries.flow.async_configure(
|
|
||||||
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
|
|
||||||
)
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
# Step 2: model
|
|
||||||
assert result2["type"] is FlowResultType.FORM
|
|
||||||
result3 = await hass.config_entries.flow.async_configure(
|
|
||||||
result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
|
|
||||||
)
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
# Step 3: download
|
|
||||||
assert result3["type"] is FlowResultType.SHOW_PROGRESS
|
|
||||||
result4 = await hass.config_entries.flow.async_configure(
|
|
||||||
result3["flow_id"],
|
|
||||||
)
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
# Run again without the task finishing.
|
|
||||||
# We should still be downloading.
|
|
||||||
assert result4["type"] is FlowResultType.SHOW_PROGRESS
|
|
||||||
result4 = await hass.config_entries.flow.async_configure(
|
|
||||||
result4["flow_id"],
|
|
||||||
)
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
assert result4["type"] is FlowResultType.SHOW_PROGRESS
|
|
||||||
|
|
||||||
# Signal fake pull method to complete
|
|
||||||
pull_ready.set()
|
|
||||||
async with asyncio.timeout(1):
|
|
||||||
await pull_called.wait()
|
|
||||||
|
|
||||||
assert pull_model == TEST_MODEL
|
|
||||||
|
|
||||||
# Step 4: finish
|
|
||||||
result5 = await hass.config_entries.flow.async_configure(
|
|
||||||
result4["flow_id"],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result5["type"] is FlowResultType.CREATE_ENTRY
|
|
||||||
assert result5["data"] == {
|
|
||||||
ollama.CONF_URL: "http://localhost:11434",
|
|
||||||
ollama.CONF_MODEL: TEST_MODEL,
|
|
||||||
}
|
|
||||||
assert len(mock_setup_entry.mock_calls) == 1
|
|
||||||
|
|
||||||
|
|
||||||
async def test_subentry_options(
|
async def test_subentry_options(
|
||||||
hass: HomeAssistant, mock_config_entry, mock_init_component
|
hass: HomeAssistant, mock_config_entry, mock_init_component
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -193,34 +94,84 @@ async def test_subentry_options(
|
|||||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
|
|
||||||
# Test reconfiguration
|
# Test reconfiguration
|
||||||
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
|
with patch(
|
||||||
hass, subentry.subentry_id
|
"ollama.AsyncClient.list",
|
||||||
)
|
return_value={"models": [{"model": TEST_MODEL}]},
|
||||||
|
):
|
||||||
|
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
|
||||||
|
hass, subentry.subentry_id
|
||||||
|
)
|
||||||
|
|
||||||
assert options_flow["type"] is FlowResultType.FORM
|
assert options_flow["type"] is FlowResultType.FORM
|
||||||
assert options_flow["step_id"] == "set_options"
|
assert options_flow["step_id"] == "set_options"
|
||||||
|
|
||||||
options = await hass.config_entries.subentries.async_configure(
|
options = await hass.config_entries.subentries.async_configure(
|
||||||
options_flow["flow_id"],
|
options_flow["flow_id"],
|
||||||
{
|
{
|
||||||
ollama.CONF_PROMPT: "test prompt",
|
ollama.CONF_MODEL: TEST_MODEL,
|
||||||
ollama.CONF_MAX_HISTORY: 100,
|
ollama.CONF_PROMPT: "test prompt",
|
||||||
ollama.CONF_NUM_CTX: 32768,
|
ollama.CONF_MAX_HISTORY: 100,
|
||||||
ollama.CONF_THINK: True,
|
ollama.CONF_NUM_CTX: 32768,
|
||||||
},
|
ollama.CONF_THINK: True,
|
||||||
)
|
},
|
||||||
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert options["type"] is FlowResultType.ABORT
|
assert options["type"] is FlowResultType.ABORT
|
||||||
assert options["reason"] == "reconfigure_successful"
|
assert options["reason"] == "reconfigure_successful"
|
||||||
assert subentry.data == {
|
assert subentry.data == {
|
||||||
|
ollama.CONF_MODEL: TEST_MODEL,
|
||||||
ollama.CONF_PROMPT: "test prompt",
|
ollama.CONF_PROMPT: "test prompt",
|
||||||
ollama.CONF_MAX_HISTORY: 100,
|
ollama.CONF_MAX_HISTORY: 100.0,
|
||||||
ollama.CONF_NUM_CTX: 32768,
|
ollama.CONF_NUM_CTX: 32768.0,
|
||||||
ollama.CONF_THINK: True,
|
ollama.CONF_THINK: True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_creating_new_conversation_subentry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating a new conversation subentry includes name field."""
|
||||||
|
# Start a new subentry flow
|
||||||
|
with patch(
|
||||||
|
"ollama.AsyncClient.list",
|
||||||
|
return_value={"models": [{"model": TEST_MODEL}]},
|
||||||
|
):
|
||||||
|
new_flow = await hass.config_entries.subentries.async_init(
|
||||||
|
(mock_config_entry.entry_id, "conversation"),
|
||||||
|
context={"source": config_entries.SOURCE_USER},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert new_flow["type"] is FlowResultType.FORM
|
||||||
|
assert new_flow["step_id"] == "set_options"
|
||||||
|
|
||||||
|
# Configure the new subentry with name field
|
||||||
|
result = await hass.config_entries.subentries.async_configure(
|
||||||
|
new_flow["flow_id"],
|
||||||
|
{
|
||||||
|
ollama.CONF_MODEL: TEST_MODEL,
|
||||||
|
CONF_NAME: "New Test Conversation",
|
||||||
|
ollama.CONF_PROMPT: "new test prompt",
|
||||||
|
ollama.CONF_MAX_HISTORY: 50,
|
||||||
|
ollama.CONF_NUM_CTX: 16384,
|
||||||
|
ollama.CONF_THINK: False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||||
|
assert result["title"] == "New Test Conversation"
|
||||||
|
assert result["data"] == {
|
||||||
|
ollama.CONF_MODEL: TEST_MODEL,
|
||||||
|
ollama.CONF_PROMPT: "new test prompt",
|
||||||
|
ollama.CONF_MAX_HISTORY: 50.0,
|
||||||
|
ollama.CONF_NUM_CTX: 16384.0,
|
||||||
|
ollama.CONF_THINK: False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_creating_conversation_subentry_not_loaded(
|
async def test_creating_conversation_subentry_not_loaded(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
@ -237,6 +188,125 @@ async def test_creating_conversation_subentry_not_loaded(
|
|||||||
assert result["reason"] == "entry_not_loaded"
|
assert result["reason"] == "entry_not_loaded"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_subentry_need_download(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_init_component,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
) -> None:
|
||||||
|
"""Test subentry creation when model needs to be downloaded."""
|
||||||
|
|
||||||
|
async def delayed_pull(self, model: str) -> None:
|
||||||
|
"""Simulate a delayed model download."""
|
||||||
|
assert model == "llama3.2:latest"
|
||||||
|
await asyncio.sleep(0) # yield the event loop 1 iteration
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"ollama.AsyncClient.list",
|
||||||
|
return_value={"models": [{"model": TEST_MODEL}]},
|
||||||
|
),
|
||||||
|
patch("ollama.AsyncClient.pull", delayed_pull),
|
||||||
|
):
|
||||||
|
new_flow = await hass.config_entries.subentries.async_init(
|
||||||
|
(mock_config_entry.entry_id, "conversation"),
|
||||||
|
context={"source": config_entries.SOURCE_USER},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert new_flow["type"] is FlowResultType.FORM, new_flow
|
||||||
|
assert new_flow["step_id"] == "set_options"
|
||||||
|
|
||||||
|
# Configure the new subentry with a model that needs downloading
|
||||||
|
result = await hass.config_entries.subentries.async_configure(
|
||||||
|
new_flow["flow_id"],
|
||||||
|
{
|
||||||
|
ollama.CONF_MODEL: "llama3.2:latest", # not cached
|
||||||
|
CONF_NAME: "New Test Conversation",
|
||||||
|
ollama.CONF_PROMPT: "new test prompt",
|
||||||
|
ollama.CONF_MAX_HISTORY: 50,
|
||||||
|
ollama.CONF_NUM_CTX: 16384,
|
||||||
|
ollama.CONF_THINK: False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.SHOW_PROGRESS
|
||||||
|
assert result["step_id"] == "download"
|
||||||
|
assert result["progress_action"] == "download"
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
result = await hass.config_entries.subentries.async_configure(
|
||||||
|
new_flow["flow_id"], {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||||
|
assert result["title"] == "New Test Conversation"
|
||||||
|
assert result["data"] == {
|
||||||
|
ollama.CONF_MODEL: "llama3.2:latest",
|
||||||
|
ollama.CONF_PROMPT: "new test prompt",
|
||||||
|
ollama.CONF_MAX_HISTORY: 50.0,
|
||||||
|
ollama.CONF_NUM_CTX: 16384.0,
|
||||||
|
ollama.CONF_THINK: False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_subentry_download_error(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_init_component,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
) -> None:
|
||||||
|
"""Test subentry creation when model download fails."""
|
||||||
|
|
||||||
|
async def delayed_pull(self, model: str) -> None:
|
||||||
|
"""Simulate a delayed model download."""
|
||||||
|
await asyncio.sleep(0) # yield
|
||||||
|
|
||||||
|
raise RuntimeError("Download failed")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"ollama.AsyncClient.list",
|
||||||
|
return_value={"models": [{"model": TEST_MODEL}]},
|
||||||
|
),
|
||||||
|
patch("ollama.AsyncClient.pull", delayed_pull),
|
||||||
|
):
|
||||||
|
new_flow = await hass.config_entries.subentries.async_init(
|
||||||
|
(mock_config_entry.entry_id, "conversation"),
|
||||||
|
context={"source": config_entries.SOURCE_USER},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert new_flow["type"] is FlowResultType.FORM
|
||||||
|
assert new_flow["step_id"] == "set_options"
|
||||||
|
|
||||||
|
# Configure with a model that needs downloading but will fail
|
||||||
|
result = await hass.config_entries.subentries.async_configure(
|
||||||
|
new_flow["flow_id"],
|
||||||
|
{
|
||||||
|
ollama.CONF_MODEL: "llama3.2:latest",
|
||||||
|
CONF_NAME: "New Test Conversation",
|
||||||
|
ollama.CONF_PROMPT: "new test prompt",
|
||||||
|
ollama.CONF_MAX_HISTORY: 50,
|
||||||
|
ollama.CONF_NUM_CTX: 16384,
|
||||||
|
ollama.CONF_THINK: False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should show progress flow result for download
|
||||||
|
assert result["type"] is FlowResultType.SHOW_PROGRESS
|
||||||
|
assert result["step_id"] == "download"
|
||||||
|
assert result["progress_action"] == "download"
|
||||||
|
|
||||||
|
# Wait for download task to complete (with error)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Submit the progress flow - should get failure
|
||||||
|
result = await hass.config_entries.subentries.async_configure(
|
||||||
|
new_flow["flow_id"], {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "download_failed"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("side_effect", "error"),
|
("side_effect", "error"),
|
||||||
[
|
[
|
||||||
@ -262,40 +332,132 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
|
|||||||
assert result2["errors"] == {"base": error}
|
assert result2["errors"] == {"base": error}
|
||||||
|
|
||||||
|
|
||||||
async def test_download_error(hass: HomeAssistant) -> None:
|
async def test_form_invalid_url(hass: HomeAssistant) -> None:
|
||||||
"""Test we handle errors while downloading a model."""
|
"""Test we handle invalid URL."""
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_init(
|
||||||
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
|
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _delayed_runtime_error(*args, **kwargs):
|
result2 = await hass.config_entries.flow.async_configure(
|
||||||
await asyncio.sleep(0)
|
result["flow_id"], {ollama.CONF_URL: "not-a-valid-url"}
|
||||||
raise RuntimeError
|
)
|
||||||
|
|
||||||
|
assert result2["type"] is FlowResultType.FORM
|
||||||
|
assert result2["errors"] == {"base": "invalid_url"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_subentry_connection_error(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_init_component,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
) -> None:
|
||||||
|
"""Test subentry creation when connection to Ollama server fails."""
|
||||||
|
with patch(
|
||||||
|
"ollama.AsyncClient.list",
|
||||||
|
side_effect=ConnectError("Connection failed"),
|
||||||
|
):
|
||||||
|
new_flow = await hass.config_entries.subentries.async_init(
|
||||||
|
(mock_config_entry.entry_id, "conversation"),
|
||||||
|
context={"source": config_entries.SOURCE_USER},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert new_flow["type"] is FlowResultType.ABORT
|
||||||
|
assert new_flow["reason"] == "cannot_connect"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_subentry_model_check_exception(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_init_component,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
) -> None:
|
||||||
|
"""Test subentry creation when checking model availability throws exception."""
|
||||||
|
with patch(
|
||||||
|
"ollama.AsyncClient.list",
|
||||||
|
side_effect=[
|
||||||
|
{"models": [{"model": TEST_MODEL}]}, # First call succeeds
|
||||||
|
RuntimeError("Failed to check models"), # Second call fails
|
||||||
|
],
|
||||||
|
):
|
||||||
|
new_flow = await hass.config_entries.subentries.async_init(
|
||||||
|
(mock_config_entry.entry_id, "conversation"),
|
||||||
|
context={"source": config_entries.SOURCE_USER},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert new_flow["type"] is FlowResultType.FORM
|
||||||
|
assert new_flow["step_id"] == "set_options"
|
||||||
|
|
||||||
|
# Configure with a model, should fail when checking availability
|
||||||
|
result = await hass.config_entries.subentries.async_configure(
|
||||||
|
new_flow["flow_id"],
|
||||||
|
{
|
||||||
|
ollama.CONF_MODEL: "new_model:latest",
|
||||||
|
CONF_NAME: "Test Conversation",
|
||||||
|
ollama.CONF_PROMPT: "test prompt",
|
||||||
|
ollama.CONF_MAX_HISTORY: 50,
|
||||||
|
ollama.CONF_NUM_CTX: 16384,
|
||||||
|
ollama.CONF_THINK: False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "cannot_connect"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_subentry_reconfigure_with_download(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_init_component,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
) -> None:
|
||||||
|
"""Test reconfiguring subentry when model needs to be downloaded."""
|
||||||
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
|
|
||||||
|
async def delayed_pull(self, model: str) -> None:
|
||||||
|
"""Simulate a delayed model download."""
|
||||||
|
assert model == "llama3.2:latest"
|
||||||
|
await asyncio.sleep(0) # yield the event loop
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
|
"ollama.AsyncClient.list",
|
||||||
return_value={},
|
return_value={"models": [{"model": TEST_MODEL}]},
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull",
|
|
||||||
_delayed_runtime_error,
|
|
||||||
),
|
),
|
||||||
|
patch("ollama.AsyncClient.pull", delayed_pull),
|
||||||
):
|
):
|
||||||
result2 = await hass.config_entries.flow.async_configure(
|
reconfigure_flow = await mock_config_entry.start_subentry_reconfigure_flow(
|
||||||
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
|
hass, subentry.subentry_id
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
assert result2["type"] is FlowResultType.FORM
|
assert reconfigure_flow["type"] is FlowResultType.FORM
|
||||||
result3 = await hass.config_entries.flow.async_configure(
|
assert reconfigure_flow["step_id"] == "set_options"
|
||||||
result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
|
|
||||||
|
# Reconfigure with a model that needs downloading
|
||||||
|
result = await hass.config_entries.subentries.async_configure(
|
||||||
|
reconfigure_flow["flow_id"],
|
||||||
|
{
|
||||||
|
ollama.CONF_MODEL: "llama3.2:latest",
|
||||||
|
ollama.CONF_PROMPT: "updated prompt",
|
||||||
|
ollama.CONF_MAX_HISTORY: 75,
|
||||||
|
ollama.CONF_NUM_CTX: 8192,
|
||||||
|
ollama.CONF_THINK: True,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.SHOW_PROGRESS
|
||||||
|
assert result["step_id"] == "download"
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert result3["type"] is FlowResultType.SHOW_PROGRESS
|
# Finish download
|
||||||
result4 = await hass.config_entries.flow.async_configure(result3["flow_id"])
|
result = await hass.config_entries.subentries.async_configure(
|
||||||
await hass.async_block_till_done()
|
reconfigure_flow["flow_id"], {}
|
||||||
|
)
|
||||||
|
|
||||||
assert result4["type"] is FlowResultType.ABORT
|
assert result["type"] is FlowResultType.ABORT
|
||||||
assert result4["reason"] == "download_failed"
|
assert result["reason"] == "reconfigure_successful"
|
||||||
|
assert subentry.data == {
|
||||||
|
ollama.CONF_MODEL: "llama3.2:latest",
|
||||||
|
ollama.CONF_PROMPT: "updated prompt",
|
||||||
|
ollama.CONF_MAX_HISTORY: 75.0,
|
||||||
|
ollama.CONF_NUM_CTX: 8192.0,
|
||||||
|
ollama.CONF_THINK: True,
|
||||||
|
}
|
||||||
|
@ -15,7 +15,12 @@ from homeassistant.components.conversation import trace
|
|||||||
from homeassistant.const import ATTR_SUPPORTED_FEATURES, CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import ATTR_SUPPORTED_FEATURES, CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import intent, llm
|
from homeassistant.helpers import (
|
||||||
|
device_registry as dr,
|
||||||
|
entity_registry as er,
|
||||||
|
intent,
|
||||||
|
llm,
|
||||||
|
)
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
@ -68,7 +73,7 @@ async def test_chat(
|
|||||||
args = mock_chat.call_args.kwargs
|
args = mock_chat.call_args.kwargs
|
||||||
prompt = args["messages"][0]["content"]
|
prompt = args["messages"][0]["content"]
|
||||||
|
|
||||||
assert args["model"] == "test model"
|
assert args["model"] == "test_model:latest"
|
||||||
assert args["messages"] == [
|
assert args["messages"] == [
|
||||||
Message(role="system", content=prompt),
|
Message(role="system", content=prompt),
|
||||||
Message(role="user", content="test message"),
|
Message(role="user", content="test message"),
|
||||||
@ -128,7 +133,7 @@ async def test_chat_stream(
|
|||||||
args = mock_chat.call_args.kwargs
|
args = mock_chat.call_args.kwargs
|
||||||
prompt = args["messages"][0]["content"]
|
prompt = args["messages"][0]["content"]
|
||||||
|
|
||||||
assert args["model"] == "test model"
|
assert args["model"] == "test_model:latest"
|
||||||
assert args["messages"] == [
|
assert args["messages"] == [
|
||||||
Message(role="system", content=prompt),
|
Message(role="system", content=prompt),
|
||||||
Message(role="user", content="test message"),
|
Message(role="user", content="test message"),
|
||||||
@ -158,6 +163,7 @@ async def test_template_variables(
|
|||||||
"The user name is {{ user_name }}. "
|
"The user name is {{ user_name }}. "
|
||||||
"The user id is {{ llm_context.context.user_id }}."
|
"The user id is {{ llm_context.context.user_id }}."
|
||||||
),
|
),
|
||||||
|
ollama.CONF_MODEL: "test_model:latest",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
@ -524,7 +530,9 @@ async def test_message_history_unlimited(
|
|||||||
):
|
):
|
||||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
hass.config_entries.async_update_subentry(
|
hass.config_entries.async_update_subentry(
|
||||||
mock_config_entry, subentry, data={ollama.CONF_MAX_HISTORY: 0}
|
mock_config_entry,
|
||||||
|
subentry,
|
||||||
|
data={**subentry.data, ollama.CONF_MAX_HISTORY: 0},
|
||||||
)
|
)
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
result = await conversation.async_converse(
|
result = await conversation.async_converse(
|
||||||
@ -573,6 +581,7 @@ async def test_template_error(
|
|||||||
mock_config_entry,
|
mock_config_entry,
|
||||||
subentry,
|
subentry,
|
||||||
data={
|
data={
|
||||||
|
**subentry.data,
|
||||||
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
|
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -593,6 +602,8 @@ async def test_conversation_agent(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
mock_init_component,
|
mock_init_component,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test OllamaConversationEntity."""
|
"""Test OllamaConversationEntity."""
|
||||||
agent = conversation.get_agent_manager(hass).async_get_agent(
|
agent = conversation.get_agent_manager(hass).async_get_agent(
|
||||||
@ -604,6 +615,24 @@ async def test_conversation_agent(
|
|||||||
assert state
|
assert state
|
||||||
assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0
|
assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0
|
||||||
|
|
||||||
|
entity_entry = entity_registry.async_get("conversation.ollama_conversation")
|
||||||
|
assert entity_entry
|
||||||
|
subentry = mock_config_entry.subentries.get(entity_entry.unique_id)
|
||||||
|
assert subentry
|
||||||
|
assert entity_entry.original_name == subentry.title
|
||||||
|
|
||||||
|
device_entry = device_registry.async_get(entity_entry.device_id)
|
||||||
|
assert device_entry
|
||||||
|
|
||||||
|
assert device_entry.identifiers == {(ollama.DOMAIN, subentry.subentry_id)}
|
||||||
|
assert device_entry.name == subentry.title
|
||||||
|
assert device_entry.manufacturer == "Ollama"
|
||||||
|
assert device_entry.entry_type == dr.DeviceEntryType.SERVICE
|
||||||
|
|
||||||
|
model, _, version = subentry.data[ollama.CONF_MODEL].partition(":")
|
||||||
|
assert device_entry.model == model
|
||||||
|
assert device_entry.sw_version == version
|
||||||
|
|
||||||
|
|
||||||
async def test_conversation_agent_with_assist(
|
async def test_conversation_agent_with_assist(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
@ -679,6 +708,7 @@ async def test_reasoning_filter(
|
|||||||
mock_config_entry,
|
mock_config_entry,
|
||||||
subentry,
|
subentry,
|
||||||
data={
|
data={
|
||||||
|
**subentry.data,
|
||||||
ollama.CONF_THINK: think,
|
ollama.CONF_THINK: think,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -9,13 +9,26 @@ from homeassistant.components import ollama
|
|||||||
from homeassistant.components.ollama.const import DOMAIN
|
from homeassistant.components.ollama.const import DOMAIN
|
||||||
from homeassistant.config_entries import ConfigSubentryData
|
from homeassistant.config_entries import ConfigSubentryData
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
from homeassistant.helpers import device_registry as dr, entity_registry as er, llm
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import TEST_OPTIONS, TEST_USER_DATA
|
from . import TEST_OPTIONS
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
V1_TEST_USER_DATA = {
|
||||||
|
ollama.CONF_URL: "http://localhost:11434",
|
||||||
|
ollama.CONF_MODEL: "test_model:latest",
|
||||||
|
}
|
||||||
|
|
||||||
|
V1_TEST_OPTIONS = {
|
||||||
|
ollama.CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||||
|
ollama.CONF_MAX_HISTORY: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
V21_TEST_USER_DATA = V1_TEST_USER_DATA
|
||||||
|
V21_TEST_OPTIONS = V1_TEST_OPTIONS
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("side_effect", "error"),
|
("side_effect", "error"),
|
||||||
@ -41,17 +54,17 @@ async def test_init_error(
|
|||||||
assert error in caplog.text
|
assert error in caplog.text
|
||||||
|
|
||||||
|
|
||||||
async def test_migration_from_v1_to_v2(
|
async def test_migration_from_v1(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test migration from version 1 to version 2."""
|
"""Test migration from version 1."""
|
||||||
# Create a v1 config entry with conversation options and an entity
|
# Create a v1 config entry with conversation options and an entity
|
||||||
mock_config_entry = MockConfigEntry(
|
mock_config_entry = MockConfigEntry(
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
data=TEST_USER_DATA,
|
data=V1_TEST_USER_DATA,
|
||||||
options=TEST_OPTIONS,
|
options=V1_TEST_OPTIONS,
|
||||||
version=1,
|
version=1,
|
||||||
title="llama-3.2-8b",
|
title="llama-3.2-8b",
|
||||||
)
|
)
|
||||||
@ -81,9 +94,10 @@ async def test_migration_from_v1_to_v2(
|
|||||||
):
|
):
|
||||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||||
|
|
||||||
assert mock_config_entry.version == 2
|
assert mock_config_entry.version == 3
|
||||||
assert mock_config_entry.minor_version == 2
|
assert mock_config_entry.minor_version == 1
|
||||||
assert mock_config_entry.data == TEST_USER_DATA
|
# After migration, parent entry should only have URL
|
||||||
|
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
|
||||||
assert mock_config_entry.options == {}
|
assert mock_config_entry.options == {}
|
||||||
|
|
||||||
assert len(mock_config_entry.subentries) == 1
|
assert len(mock_config_entry.subentries) == 1
|
||||||
@ -92,7 +106,9 @@ async def test_migration_from_v1_to_v2(
|
|||||||
assert subentry.unique_id is None
|
assert subentry.unique_id is None
|
||||||
assert subentry.title == "llama-3.2-8b"
|
assert subentry.title == "llama-3.2-8b"
|
||||||
assert subentry.subentry_type == "conversation"
|
assert subentry.subentry_type == "conversation"
|
||||||
assert subentry.data == TEST_OPTIONS
|
# Subentry should now include the model from the original options
|
||||||
|
expected_subentry_data = TEST_OPTIONS.copy()
|
||||||
|
assert subentry.data == expected_subentry_data
|
||||||
|
|
||||||
migrated_entity = entity_registry.async_get(entity.entity_id)
|
migrated_entity = entity_registry.async_get(entity.entity_id)
|
||||||
assert migrated_entity is not None
|
assert migrated_entity is not None
|
||||||
@ -117,17 +133,17 @@ async def test_migration_from_v1_to_v2(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_migration_from_v1_to_v2_with_multiple_urls(
|
async def test_migration_from_v1_with_multiple_urls(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test migration from version 1 to version 2 with different URLs."""
|
"""Test migration from version 1 with different URLs."""
|
||||||
# Create two v1 config entries with different URLs
|
# Create two v1 config entries with different URLs
|
||||||
mock_config_entry = MockConfigEntry(
|
mock_config_entry = MockConfigEntry(
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
|
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
|
||||||
options=TEST_OPTIONS,
|
options=V1_TEST_OPTIONS,
|
||||||
version=1,
|
version=1,
|
||||||
title="Ollama 1",
|
title="Ollama 1",
|
||||||
)
|
)
|
||||||
@ -135,7 +151,7 @@ async def test_migration_from_v1_to_v2_with_multiple_urls(
|
|||||||
mock_config_entry_2 = MockConfigEntry(
|
mock_config_entry_2 = MockConfigEntry(
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
data={"url": "http://localhost:11435", "model": "llama3.2:latest"},
|
data={"url": "http://localhost:11435", "model": "llama3.2:latest"},
|
||||||
options=TEST_OPTIONS,
|
options=V1_TEST_OPTIONS,
|
||||||
version=1,
|
version=1,
|
||||||
title="Ollama 2",
|
title="Ollama 2",
|
||||||
)
|
)
|
||||||
@ -187,13 +203,16 @@ async def test_migration_from_v1_to_v2_with_multiple_urls(
|
|||||||
assert len(entries) == 2
|
assert len(entries) == 2
|
||||||
|
|
||||||
for idx, entry in enumerate(entries):
|
for idx, entry in enumerate(entries):
|
||||||
assert entry.version == 2
|
assert entry.version == 3
|
||||||
assert entry.minor_version == 2
|
assert entry.minor_version == 1
|
||||||
assert not entry.options
|
assert not entry.options
|
||||||
assert len(entry.subentries) == 1
|
assert len(entry.subentries) == 1
|
||||||
subentry = list(entry.subentries.values())[0]
|
subentry = list(entry.subentries.values())[0]
|
||||||
assert subentry.subentry_type == "conversation"
|
assert subentry.subentry_type == "conversation"
|
||||||
assert subentry.data == TEST_OPTIONS
|
# Subentry should include the model along with the original options
|
||||||
|
expected_subentry_data = TEST_OPTIONS.copy()
|
||||||
|
expected_subentry_data["model"] = "llama3.2:latest"
|
||||||
|
assert subentry.data == expected_subentry_data
|
||||||
assert subentry.title == f"Ollama {idx + 1}"
|
assert subentry.title == f"Ollama {idx + 1}"
|
||||||
|
|
||||||
dev = device_registry.async_get_device(
|
dev = device_registry.async_get_device(
|
||||||
@ -204,17 +223,17 @@ async def test_migration_from_v1_to_v2_with_multiple_urls(
|
|||||||
assert dev.config_entries_subentries == {entry.entry_id: {subentry.subentry_id}}
|
assert dev.config_entries_subentries == {entry.entry_id: {subentry.subentry_id}}
|
||||||
|
|
||||||
|
|
||||||
async def test_migration_from_v1_to_v2_with_same_urls(
|
async def test_migration_from_v1_with_same_urls(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test migration from version 1 to version 2 with same URLs consolidates entries."""
|
"""Test migration from version 1 with same URLs consolidates entries."""
|
||||||
# Create two v1 config entries with the same URL
|
# Create two v1 config entries with the same URL
|
||||||
mock_config_entry = MockConfigEntry(
|
mock_config_entry = MockConfigEntry(
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
|
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
|
||||||
options=TEST_OPTIONS,
|
options=V1_TEST_OPTIONS,
|
||||||
version=1,
|
version=1,
|
||||||
title="Ollama",
|
title="Ollama",
|
||||||
)
|
)
|
||||||
@ -222,7 +241,7 @@ async def test_migration_from_v1_to_v2_with_same_urls(
|
|||||||
mock_config_entry_2 = MockConfigEntry(
|
mock_config_entry_2 = MockConfigEntry(
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, # Same URL
|
data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, # Same URL
|
||||||
options=TEST_OPTIONS,
|
options=V1_TEST_OPTIONS,
|
||||||
version=1,
|
version=1,
|
||||||
title="Ollama 2",
|
title="Ollama 2",
|
||||||
)
|
)
|
||||||
@ -275,8 +294,8 @@ async def test_migration_from_v1_to_v2_with_same_urls(
|
|||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
|
|
||||||
entry = entries[0]
|
entry = entries[0]
|
||||||
assert entry.version == 2
|
assert entry.version == 3
|
||||||
assert entry.minor_version == 2
|
assert entry.minor_version == 1
|
||||||
assert not entry.options
|
assert not entry.options
|
||||||
assert len(entry.subentries) == 2 # Two subentries from the two original entries
|
assert len(entry.subentries) == 2 # Two subentries from the two original entries
|
||||||
|
|
||||||
@ -288,7 +307,10 @@ async def test_migration_from_v1_to_v2_with_same_urls(
|
|||||||
|
|
||||||
for subentry in subentries:
|
for subentry in subentries:
|
||||||
assert subentry.subentry_type == "conversation"
|
assert subentry.subentry_type == "conversation"
|
||||||
assert subentry.data == TEST_OPTIONS
|
# Subentry should include the model along with the original options
|
||||||
|
expected_subentry_data = TEST_OPTIONS.copy()
|
||||||
|
expected_subentry_data["model"] = "llama3.2:latest"
|
||||||
|
assert subentry.data == expected_subentry_data
|
||||||
|
|
||||||
# Check devices were migrated correctly
|
# Check devices were migrated correctly
|
||||||
dev = device_registry.async_get_device(
|
dev = device_registry.async_get_device(
|
||||||
@ -301,12 +323,12 @@ async def test_migration_from_v1_to_v2_with_same_urls(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_migration_from_v2_1_to_v2_2(
|
async def test_migration_from_v2_1(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
entity_registry: er.EntityRegistry,
|
entity_registry: er.EntityRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test migration from version 2.1 to version 2.2.
|
"""Test migration from version 2.1.
|
||||||
|
|
||||||
This tests we clean up the broken migration in Home Assistant Core
|
This tests we clean up the broken migration in Home Assistant Core
|
||||||
2025.7.0b0-2025.7.0b1:
|
2025.7.0b0-2025.7.0b1:
|
||||||
@ -315,20 +337,20 @@ async def test_migration_from_v2_1_to_v2_2(
|
|||||||
# Create a v2.1 config entry with 2 subentries, devices and entities
|
# Create a v2.1 config entry with 2 subentries, devices and entities
|
||||||
mock_config_entry = MockConfigEntry(
|
mock_config_entry = MockConfigEntry(
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
data=TEST_USER_DATA,
|
data=V21_TEST_USER_DATA,
|
||||||
entry_id="mock_entry_id",
|
entry_id="mock_entry_id",
|
||||||
version=2,
|
version=2,
|
||||||
minor_version=1,
|
minor_version=1,
|
||||||
subentries_data=[
|
subentries_data=[
|
||||||
ConfigSubentryData(
|
ConfigSubentryData(
|
||||||
data=TEST_OPTIONS,
|
data=V21_TEST_OPTIONS,
|
||||||
subentry_id="mock_id_1",
|
subentry_id="mock_id_1",
|
||||||
subentry_type="conversation",
|
subentry_type="conversation",
|
||||||
title="Ollama",
|
title="Ollama",
|
||||||
unique_id=None,
|
unique_id=None,
|
||||||
),
|
),
|
||||||
ConfigSubentryData(
|
ConfigSubentryData(
|
||||||
data=TEST_OPTIONS,
|
data=V21_TEST_OPTIONS,
|
||||||
subentry_id="mock_id_2",
|
subentry_id="mock_id_2",
|
||||||
subentry_type="conversation",
|
subentry_type="conversation",
|
||||||
title="Ollama 2",
|
title="Ollama 2",
|
||||||
@ -392,8 +414,8 @@ async def test_migration_from_v2_1_to_v2_2(
|
|||||||
entries = hass.config_entries.async_entries(DOMAIN)
|
entries = hass.config_entries.async_entries(DOMAIN)
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
entry = entries[0]
|
entry = entries[0]
|
||||||
assert entry.version == 2
|
assert entry.version == 3
|
||||||
assert entry.minor_version == 2
|
assert entry.minor_version == 1
|
||||||
assert not entry.options
|
assert not entry.options
|
||||||
assert entry.title == "Ollama"
|
assert entry.title == "Ollama"
|
||||||
assert len(entry.subentries) == 2
|
assert len(entry.subentries) == 2
|
||||||
@ -405,6 +427,7 @@ async def test_migration_from_v2_1_to_v2_2(
|
|||||||
assert len(conversation_subentries) == 2
|
assert len(conversation_subentries) == 2
|
||||||
for subentry in conversation_subentries:
|
for subentry in conversation_subentries:
|
||||||
assert subentry.subentry_type == "conversation"
|
assert subentry.subentry_type == "conversation"
|
||||||
|
# Since TEST_USER_DATA no longer has a model, subentry data should be TEST_OPTIONS
|
||||||
assert subentry.data == TEST_OPTIONS
|
assert subentry.data == TEST_OPTIONS
|
||||||
assert "Ollama" in subentry.title
|
assert "Ollama" in subentry.title
|
||||||
|
|
||||||
@ -450,3 +473,45 @@ async def test_migration_from_v2_1_to_v2_2(
|
|||||||
assert device.config_entries_subentries == {
|
assert device.config_entries_subentries == {
|
||||||
mock_config_entry.entry_id: {subentry.subentry_id}
|
mock_config_entry.entry_id: {subentry.subentry_id}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_migration_from_v2_2(hass: HomeAssistant) -> None:
|
||||||
|
"""Test migration from version 2.2."""
|
||||||
|
subentry_data = ConfigSubentryData(
|
||||||
|
data=V21_TEST_USER_DATA,
|
||||||
|
subentry_type="conversation",
|
||||||
|
title="Test Conversation",
|
||||||
|
unique_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_config_entry = MockConfigEntry(
|
||||||
|
domain=DOMAIN,
|
||||||
|
data={
|
||||||
|
ollama.CONF_URL: "http://localhost:11434",
|
||||||
|
ollama.CONF_MODEL: "test_model:latest", # Model still in main data
|
||||||
|
},
|
||||||
|
version=2,
|
||||||
|
minor_version=2,
|
||||||
|
subentries_data=[subentry_data],
|
||||||
|
)
|
||||||
|
mock_config_entry.add_to_hass(hass)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.ollama.async_setup_entry",
|
||||||
|
return_value=True,
|
||||||
|
):
|
||||||
|
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||||
|
|
||||||
|
# Check migration to v3.1
|
||||||
|
assert mock_config_entry.version == 3
|
||||||
|
assert mock_config_entry.minor_version == 1
|
||||||
|
|
||||||
|
# Check that model was moved from main data to subentry
|
||||||
|
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
|
||||||
|
assert len(mock_config_entry.subentries) == 1
|
||||||
|
|
||||||
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
|
assert subentry.data == {
|
||||||
|
**V21_TEST_USER_DATA,
|
||||||
|
ollama.CONF_MODEL: "test_model:latest",
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user