Convert Ollama to subentries (#147286)

* Convert Ollama to subentries

* Add latest changes from Google subentries

* Move config entry type to init
This commit is contained in:
Paulus Schoutsen 2025-06-24 16:13:34 -04:00 committed by GitHub
parent 5a20ef3f3f
commit f735331699
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 625 additions and 139 deletions

View File

@ -8,11 +8,16 @@ import logging
import httpx import httpx
import ollama import ollama
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_URL, Platform from homeassistant.const import CONF_URL, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import (
config_validation as cv,
device_registry as dr,
entity_registry as er,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.util.ssl import get_default_context from homeassistant.util.ssl import get_default_context
from .const import ( from .const import (
@ -42,8 +47,16 @@ __all__ = [
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION,) PLATFORMS = (Platform.CONVERSATION,)
type OllamaConfigEntry = ConfigEntry[ollama.AsyncClient]
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up Ollama."""
await async_migrate_integration(hass)
return True
async def async_setup_entry(hass: HomeAssistant, entry: OllamaConfigEntry) -> bool:
"""Set up Ollama from a config entry.""" """Set up Ollama from a config entry."""
settings = {**entry.data, **entry.options} settings = {**entry.data, **entry.options}
client = ollama.AsyncClient(host=settings[CONF_URL], verify=get_default_context()) client = ollama.AsyncClient(host=settings[CONF_URL], verify=get_default_context())
@ -53,8 +66,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
except (TimeoutError, httpx.ConnectError) as err: except (TimeoutError, httpx.ConnectError) as err:
raise ConfigEntryNotReady(err) from err raise ConfigEntryNotReady(err) from err
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client entry.runtime_data = client
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True return True
@ -63,5 +75,69 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Ollama.""" """Unload Ollama."""
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS): if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
return False return False
hass.data[DOMAIN].pop(entry.entry_id)
return True return True
async def async_migrate_integration(hass: HomeAssistant) -> None:
"""Migrate integration entry structure."""
entries = hass.config_entries.async_entries(DOMAIN)
if not any(entry.version == 1 for entry in entries):
return
api_keys_entries: dict[str, ConfigEntry] = {}
entity_registry = er.async_get(hass)
device_registry = dr.async_get(hass)
for entry in entries:
use_existing = False
subentry = ConfigSubentry(
data=entry.options,
subentry_type="conversation",
title=entry.title,
unique_id=None,
)
if entry.data[CONF_URL] not in api_keys_entries:
use_existing = True
api_keys_entries[entry.data[CONF_URL]] = entry
parent_entry = api_keys_entries[entry.data[CONF_URL]]
hass.config_entries.async_add_subentry(parent_entry, subentry)
conversation_entity = entity_registry.async_get_entity_id(
"conversation",
DOMAIN,
entry.entry_id,
)
if conversation_entity is not None:
entity_registry.async_update_entity(
conversation_entity,
config_entry_id=parent_entry.entry_id,
config_subentry_id=subentry.subentry_id,
new_unique_id=subentry.subentry_id,
)
device = device_registry.async_get_device(
identifiers={(DOMAIN, entry.entry_id)}
)
if device is not None:
device_registry.async_update_device(
device.id,
new_identifiers={(DOMAIN, subentry.subentry_id)},
add_config_subentry_id=subentry.subentry_id,
add_config_entry_id=parent_entry.entry_id,
)
if parent_entry.entry_id != entry.entry_id:
device_registry.async_update_device(
device.id,
remove_config_entry_id=entry.entry_id,
)
if not use_existing:
await hass.config_entries.async_remove(entry.entry_id)
else:
hass.config_entries.async_update_entry(
entry,
options={},
version=2,
)

View File

@ -14,12 +14,14 @@ import voluptuous as vol
from homeassistant.config_entries import ( from homeassistant.config_entries import (
ConfigEntry, ConfigEntry,
ConfigEntryState,
ConfigFlow, ConfigFlow,
ConfigFlowResult, ConfigFlowResult,
OptionsFlow, ConfigSubentryFlow,
SubentryFlowResult,
) )
from homeassistant.const import CONF_LLM_HASS_API, CONF_URL from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import llm from homeassistant.helpers import llm
from homeassistant.helpers.selector import ( from homeassistant.helpers.selector import (
BooleanSelector, BooleanSelector,
@ -70,7 +72,7 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Ollama.""" """Handle a config flow for Ollama."""
VERSION = 1 VERSION = 2
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize config flow.""" """Initialize config flow."""
@ -94,6 +96,8 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
errors = {} errors = {}
self._async_abort_entries_match({CONF_URL: self.url})
try: try:
self.client = ollama.AsyncClient( self.client = ollama.AsyncClient(
host=self.url, verify=get_default_context() host=self.url, verify=get_default_context()
@ -146,8 +150,16 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
return await self.async_step_download() return await self.async_step_download()
return self.async_create_entry( return self.async_create_entry(
title=_get_title(self.model), title=self.url,
data={CONF_URL: self.url, CONF_MODEL: self.model}, data={CONF_URL: self.url, CONF_MODEL: self.model},
subentries=[
{
"subentry_type": "conversation",
"data": {},
"title": _get_title(self.model),
"unique_id": None,
}
],
) )
async def async_step_download( async def async_step_download(
@ -189,6 +201,14 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_create_entry( return self.async_create_entry(
title=_get_title(self.model), title=_get_title(self.model),
data={CONF_URL: self.url, CONF_MODEL: self.model}, data={CONF_URL: self.url, CONF_MODEL: self.model},
subentries=[
{
"subentry_type": "conversation",
"data": {},
"title": _get_title(self.model),
"unique_id": None,
}
],
) )
async def async_step_failed( async def async_step_failed(
@ -197,41 +217,62 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
"""Step after model downloading has failed.""" """Step after model downloading has failed."""
return self.async_abort(reason="download_failed") return self.async_abort(reason="download_failed")
@staticmethod @classmethod
def async_get_options_flow( @callback
config_entry: ConfigEntry, def async_get_supported_subentry_types(
) -> OptionsFlow: cls, config_entry: ConfigEntry
"""Create the options flow.""" ) -> dict[str, type[ConfigSubentryFlow]]:
return OllamaOptionsFlow(config_entry) """Return subentries supported by this integration."""
return {"conversation": ConversationSubentryFlowHandler}
class OllamaOptionsFlow(OptionsFlow): class ConversationSubentryFlowHandler(ConfigSubentryFlow):
"""Ollama options flow.""" """Flow for managing conversation subentries."""
def __init__(self, config_entry: ConfigEntry) -> None: @property
"""Initialize options flow.""" def _is_new(self) -> bool:
self.url: str = config_entry.data[CONF_URL] """Return if this is a new subentry."""
self.model: str = config_entry.data[CONF_MODEL] return self.source == "user"
async def async_step_init( async def async_step_set_options(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult: ) -> SubentryFlowResult:
"""Manage the options.""" """Set conversation options."""
if user_input is not None: # abort if entry is not loaded
if self._get_entry().state != ConfigEntryState.LOADED:
return self.async_abort(reason="entry_not_loaded")
errors: dict[str, str] = {}
if user_input is None:
if self._is_new:
options = {}
else:
options = self._get_reconfigure_subentry().data.copy()
elif self._is_new:
return self.async_create_entry( return self.async_create_entry(
title=_get_title(self.model), data=user_input title=user_input.pop(CONF_NAME),
data=user_input,
)
else:
return self.async_update_and_abort(
self._get_entry(),
self._get_reconfigure_subentry(),
data=user_input,
) )
options: Mapping[str, Any] = self.config_entry.options or {} schema = ollama_config_option_schema(self.hass, self._is_new, options)
schema = ollama_config_option_schema(self.hass, options)
return self.async_show_form( return self.async_show_form(
step_id="init", step_id="set_options", data_schema=vol.Schema(schema), errors=errors
data_schema=vol.Schema(schema),
) )
async_step_user = async_step_set_options
async_step_reconfigure = async_step_set_options
def ollama_config_option_schema( def ollama_config_option_schema(
hass: HomeAssistant, options: Mapping[str, Any] hass: HomeAssistant, is_new: bool, options: Mapping[str, Any]
) -> dict: ) -> dict:
"""Ollama options schema.""" """Ollama options schema."""
hass_apis: list[SelectOptionDict] = [ hass_apis: list[SelectOptionDict] = [
@ -242,54 +283,72 @@ def ollama_config_option_schema(
for api in llm.async_get_apis(hass) for api in llm.async_get_apis(hass)
] ]
return { if is_new:
vol.Optional( schema: dict[vol.Required | vol.Optional, Any] = {
CONF_PROMPT, vol.Required(CONF_NAME, default="Ollama Conversation"): str,
description={ }
"suggested_value": options.get( else:
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT schema = {}
schema.update(
{
vol.Optional(
CONF_PROMPT,
description={
"suggested_value": options.get(
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
)
},
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
vol.Optional(
CONF_NUM_CTX,
description={
"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)
},
): NumberSelector(
NumberSelectorConfig(
min=MIN_NUM_CTX,
max=MAX_NUM_CTX,
step=1,
mode=NumberSelectorMode.BOX,
) )
}, ),
): TemplateSelector(), vol.Optional(
vol.Optional( CONF_MAX_HISTORY,
CONF_LLM_HASS_API, description={
description={"suggested_value": options.get(CONF_LLM_HASS_API)}, "suggested_value": options.get(
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)), CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY
vol.Optional( )
CONF_NUM_CTX, },
description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)}, ): NumberSelector(
): NumberSelector( NumberSelectorConfig(
NumberSelectorConfig( min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
min=MIN_NUM_CTX, max=MAX_NUM_CTX, step=1, mode=NumberSelectorMode.BOX )
) ),
), vol.Optional(
vol.Optional( CONF_KEEP_ALIVE,
CONF_MAX_HISTORY, description={
description={ "suggested_value": options.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)
"suggested_value": options.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY) },
}, ): NumberSelector(
): NumberSelector( NumberSelectorConfig(
NumberSelectorConfig( min=-1, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX )
) ),
), vol.Optional(
vol.Optional( CONF_THINK,
CONF_KEEP_ALIVE, description={
description={ "suggested_value": options.get("think", DEFAULT_THINK),
"suggested_value": options.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE) },
}, ): BooleanSelector(),
): NumberSelector( }
NumberSelectorConfig( )
min=-1, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
) return schema
),
vol.Optional(
CONF_THINK,
description={
"suggested_value": options.get("think", DEFAULT_THINK),
},
): BooleanSelector(),
}
def _get_title(model: str) -> str: def _get_title(model: str) -> str:

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator, Callable from collections.abc import AsyncGenerator, AsyncIterator, Callable
import json import json
import logging import logging
from typing import Any, Literal from typing import Any, Literal
@ -11,13 +11,14 @@ import ollama
from voluptuous_openapi import convert from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm from homeassistant.helpers import device_registry as dr, intent, llm
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import OllamaConfigEntry
from .const import ( from .const import (
CONF_KEEP_ALIVE, CONF_KEEP_ALIVE,
CONF_MAX_HISTORY, CONF_MAX_HISTORY,
@ -40,12 +41,18 @@ _LOGGER = logging.getLogger(__name__)
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
config_entry: ConfigEntry, config_entry: OllamaConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback, async_add_entities: AddConfigEntryEntitiesCallback,
) -> None: ) -> None:
"""Set up conversation entities.""" """Set up conversation entities."""
agent = OllamaConversationEntity(config_entry) for subentry in config_entry.subentries.values():
async_add_entities([agent]) if subentry.subentry_type != "conversation":
continue
async_add_entities(
[OllamaConversationEntity(config_entry, subentry)],
config_subentry_id=subentry.subentry_id,
)
def _format_tool( def _format_tool(
@ -130,7 +137,7 @@ def _convert_content(
async def _transform_stream( async def _transform_stream(
result: AsyncGenerator[ollama.Message], result: AsyncIterator[ollama.ChatResponse],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]: ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
"""Transform the response stream into HA format. """Transform the response stream into HA format.
@ -174,17 +181,22 @@ class OllamaConversationEntity(
): ):
"""Ollama conversation agent.""" """Ollama conversation agent."""
_attr_has_entity_name = True
_attr_supports_streaming = True _attr_supports_streaming = True
def __init__(self, entry: ConfigEntry) -> None: def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None:
"""Initialize the agent.""" """Initialize the agent."""
self.entry = entry self.entry = entry
self.subentry = subentry
# conversation id -> message history self._attr_name = subentry.title
self._attr_name = entry.title self._attr_unique_id = subentry.subentry_id
self._attr_unique_id = entry.entry_id self._attr_device_info = dr.DeviceInfo(
if self.entry.options.get(CONF_LLM_HASS_API): identifiers={(DOMAIN, subentry.subentry_id)},
name=subentry.title,
manufacturer="Ollama",
model=entry.data[CONF_MODEL],
entry_type=dr.DeviceEntryType.SERVICE,
)
if self.subentry.data.get(CONF_LLM_HASS_API):
self._attr_supported_features = ( self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL conversation.ConversationEntityFeature.CONTROL
) )
@ -216,7 +228,7 @@ class OllamaConversationEntity(
chat_log: conversation.ChatLog, chat_log: conversation.ChatLog,
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Call the API.""" """Call the API."""
settings = {**self.entry.data, **self.entry.options} settings = {**self.entry.data, **self.subentry.data}
try: try:
await chat_log.async_provide_llm_data( await chat_log.async_provide_llm_data(
@ -248,9 +260,9 @@ class OllamaConversationEntity(
chat_log: conversation.ChatLog, chat_log: conversation.ChatLog,
) -> None: ) -> None:
"""Generate an answer for the chat log.""" """Generate an answer for the chat log."""
settings = {**self.entry.data, **self.entry.options} settings = {**self.entry.data, **self.subentry.data}
client = self.hass.data[DOMAIN][self.entry.entry_id] client = self.entry.runtime_data
model = settings[CONF_MODEL] model = settings[CONF_MODEL]
tools: list[dict[str, Any]] | None = None tools: list[dict[str, Any]] | None = None

View File

@ -12,7 +12,8 @@
} }
}, },
"abort": { "abort": {
"download_failed": "Model downloading failed" "download_failed": "Model downloading failed",
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]"
}, },
"error": { "error": {
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
@ -22,23 +23,35 @@
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details." "download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
} }
}, },
"options": { "config_subentries": {
"step": { "conversation": {
"init": { "initiate_flow": {
"data": { "user": "Add conversation agent",
"prompt": "Instructions", "reconfigure": "Reconfigure conversation agent"
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", },
"max_history": "Max history messages", "entry_type": "Conversation agent",
"num_ctx": "Context window size", "step": {
"keep_alive": "Keep alive", "set_options": {
"think": "Think before responding" "data": {
}, "name": "[%key:common::config_flow::data::name%]",
"data_description": { "prompt": "Instructions",
"prompt": "Instruct how the LLM should respond. This can be a template.", "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never.", "max_history": "Max history messages",
"num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.", "num_ctx": "Context window size",
"think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency." "keep_alive": "Keep alive",
"think": "Think before responding"
},
"data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template.",
"keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never.",
"num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.",
"think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency."
}
} }
},
"abort": {
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
"entry_not_loaded": "Cannot add things while the configuration is disabled."
} }
} }
} }

View File

@ -30,7 +30,15 @@ def mock_config_entry(
entry = MockConfigEntry( entry = MockConfigEntry(
domain=ollama.DOMAIN, domain=ollama.DOMAIN,
data=TEST_USER_DATA, data=TEST_USER_DATA,
options=mock_config_entry_options, version=2,
subentries_data=[
{
"data": mock_config_entry_options,
"subentry_type": "conversation",
"title": "Ollama Conversation",
"unique_id": None,
}
],
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
return entry return entry
@ -41,8 +49,10 @@ def mock_config_entry_with_assist(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> MockConfigEntry: ) -> MockConfigEntry:
"""Mock a config entry with assist.""" """Mock a config entry with assist."""
hass.config_entries.async_update_entry( hass.config_entries.async_update_subentry(
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST} mock_config_entry,
next(iter(mock_config_entry.subentries.values())),
data={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
) )
return mock_config_entry return mock_config_entry

View File

@ -63,6 +63,37 @@ async def test_form(hass: HomeAssistant) -> None:
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
async def test_duplicate_entry(hass: HomeAssistant) -> None:
"""Test we abort on duplicate config entry."""
MockConfigEntry(
domain=ollama.DOMAIN,
data={
ollama.CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: "test_model",
},
).add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert not result["errors"]
with patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
return_value={"models": [{"model": "test_model"}]},
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
ollama.CONF_URL: "http://localhost:11434",
},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "already_configured"
async def test_form_need_download(hass: HomeAssistant) -> None: async def test_form_need_download(hass: HomeAssistant) -> None:
"""Test flow when a model needs to be downloaded.""" """Test flow when a model needs to be downloaded."""
# Pretend we already set up a config entry. # Pretend we already set up a config entry.
@ -155,14 +186,21 @@ async def test_form_need_download(hass: HomeAssistant) -> None:
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
async def test_options( async def test_subentry_options(
hass: HomeAssistant, mock_config_entry, mock_init_component hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None: ) -> None:
"""Test the options form.""" """Test the subentry options form."""
options_flow = await hass.config_entries.options.async_init( subentry = next(iter(mock_config_entry.subentries.values()))
mock_config_entry.entry_id
# Test reconfiguration
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
hass, subentry.subentry_id
) )
options = await hass.config_entries.options.async_configure(
assert options_flow["type"] is FlowResultType.FORM
assert options_flow["step_id"] == "set_options"
options = await hass.config_entries.subentries.async_configure(
options_flow["flow_id"], options_flow["flow_id"],
{ {
ollama.CONF_PROMPT: "test prompt", ollama.CONF_PROMPT: "test prompt",
@ -172,8 +210,10 @@ async def test_options(
}, },
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"] == { assert options["type"] is FlowResultType.ABORT
assert options["reason"] == "reconfigure_successful"
assert subentry.data == {
ollama.CONF_PROMPT: "test prompt", ollama.CONF_PROMPT: "test prompt",
ollama.CONF_MAX_HISTORY: 100, ollama.CONF_MAX_HISTORY: 100,
ollama.CONF_NUM_CTX: 32768, ollama.CONF_NUM_CTX: 32768,
@ -181,6 +221,22 @@ async def test_options(
} }
async def test_creating_conversation_subentry_not_loaded(
hass: HomeAssistant,
mock_init_component,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating a conversation subentry when entry is not loaded."""
await hass.config_entries.async_unload(mock_config_entry.entry_id)
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "entry_not_loaded"
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "error"), ("side_effect", "error"),
[ [

View File

@ -35,7 +35,7 @@ async def stream_generator(response: dict | list[dict]) -> AsyncGenerator[dict]:
yield msg yield msg
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"]) @pytest.mark.parametrize("agent_id", [None, "conversation.ollama_conversation"])
async def test_chat( async def test_chat(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
@ -149,9 +149,11 @@ async def test_template_variables(
mock_user.id = "12345" mock_user.id = "12345"
mock_user.name = "Test User" mock_user.name = "Test User"
hass.config_entries.async_update_entry( subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry(
mock_config_entry, mock_config_entry,
options={ subentry,
data={
"prompt": ( "prompt": (
"The user name is {{ user_name }}. " "The user name is {{ user_name }}. "
"The user id is {{ llm_context.context.user_id }}." "The user id is {{ llm_context.context.user_id }}."
@ -382,10 +384,12 @@ async def test_unknown_hass_api(
mock_init_component, mock_init_component,
) -> None: ) -> None:
"""Test when we reference an API that no longer exists.""" """Test when we reference an API that no longer exists."""
hass.config_entries.async_update_entry( subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry(
mock_config_entry, mock_config_entry,
options={ subentry,
**mock_config_entry.options, data={
**subentry.data,
CONF_LLM_HASS_API: "non-existing", CONF_LLM_HASS_API: "non-existing",
}, },
) )
@ -518,8 +522,9 @@ async def test_message_history_unlimited(
with ( with (
patch("ollama.AsyncClient.chat", side_effect=stream) as mock_chat, patch("ollama.AsyncClient.chat", side_effect=stream) as mock_chat,
): ):
hass.config_entries.async_update_entry( subentry = next(iter(mock_config_entry.subentries.values()))
mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0} hass.config_entries.async_update_subentry(
mock_config_entry, subentry, data={ollama.CONF_MAX_HISTORY: 0}
) )
for i in range(100): for i in range(100):
result = await conversation.async_converse( result = await conversation.async_converse(
@ -563,9 +568,11 @@ async def test_template_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None: ) -> None:
"""Test that template error handling works.""" """Test that template error handling works."""
hass.config_entries.async_update_entry( subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry(
mock_config_entry, mock_config_entry,
options={ subentry,
data={
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
}, },
) )
@ -593,7 +600,7 @@ async def test_conversation_agent(
) )
assert agent.supported_languages == MATCH_ALL assert agent.supported_languages == MATCH_ALL
state = hass.states.get("conversation.mock_title") state = hass.states.get("conversation.ollama_conversation")
assert state assert state
assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0 assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0
@ -609,7 +616,7 @@ async def test_conversation_agent_with_assist(
) )
assert agent.supported_languages == MATCH_ALL assert agent.supported_languages == MATCH_ALL
state = hass.states.get("conversation.mock_title") state = hass.states.get("conversation.ollama_conversation")
assert state assert state
assert ( assert (
state.attributes[ATTR_SUPPORTED_FEATURES] state.attributes[ATTR_SUPPORTED_FEATURES]
@ -642,7 +649,7 @@ async def test_options(
"test message", "test message",
None, None,
Context(), Context(),
agent_id="conversation.mock_title", agent_id="conversation.ollama_conversation",
) )
assert mock_chat.call_count == 1 assert mock_chat.call_count == 1
@ -667,9 +674,11 @@ async def test_reasoning_filter(
entry = MockConfigEntry() entry = MockConfigEntry()
entry.add_to_hass(hass) entry.add_to_hass(hass)
hass.config_entries.async_update_entry( subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry(
mock_config_entry, mock_config_entry,
options={ subentry,
data={
ollama.CONF_THINK: think, ollama.CONF_THINK: think,
}, },
) )

View File

@ -6,9 +6,13 @@ from httpx import ConnectError
import pytest import pytest
from homeassistant.components import ollama from homeassistant.components import ollama
from homeassistant.components.ollama.const import DOMAIN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import TEST_OPTIONS, TEST_USER_DATA
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -34,3 +38,250 @@ async def test_init_error(
assert await async_setup_component(hass, ollama.DOMAIN, {}) assert await async_setup_component(hass, ollama.DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
assert error in caplog.text assert error in caplog.text
async def test_migration_from_v1_to_v2(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 1 to version 2."""
# Create a v1 config entry with conversation options and an entity
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data=TEST_USER_DATA,
options=TEST_OPTIONS,
version=1,
title="llama-3.2-8b",
)
mock_config_entry.add_to_hass(hass)
device = device_registry.async_get_or_create(
config_entry_id=mock_config_entry.entry_id,
identifiers={(DOMAIN, mock_config_entry.entry_id)},
name=mock_config_entry.title,
manufacturer="Ollama",
model="Ollama",
entry_type=dr.DeviceEntryType.SERVICE,
)
entity = entity_registry.async_get_or_create(
"conversation",
DOMAIN,
mock_config_entry.entry_id,
config_entry=mock_config_entry,
device_id=device.id,
suggested_object_id="llama_3_2_8b",
)
# Run migration
with patch(
"homeassistant.components.ollama.async_setup_entry",
return_value=True,
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
assert mock_config_entry.version == 2
assert mock_config_entry.data == TEST_USER_DATA
assert mock_config_entry.options == {}
assert len(mock_config_entry.subentries) == 1
subentry = next(iter(mock_config_entry.subentries.values()))
assert subentry.unique_id is None
assert subentry.title == "llama-3.2-8b"
assert subentry.subentry_type == "conversation"
assert subentry.data == TEST_OPTIONS
migrated_entity = entity_registry.async_get(entity.entity_id)
assert migrated_entity is not None
assert migrated_entity.config_entry_id == mock_config_entry.entry_id
assert migrated_entity.config_subentry_id == subentry.subentry_id
assert migrated_entity.unique_id == subentry.subentry_id
# Check device migration
assert not device_registry.async_get_device(
identifiers={(DOMAIN, mock_config_entry.entry_id)}
)
assert (
migrated_device := device_registry.async_get_device(
identifiers={(DOMAIN, subentry.subentry_id)}
)
)
assert migrated_device.identifiers == {(DOMAIN, subentry.subentry_id)}
assert migrated_device.id == device.id
async def test_migration_from_v1_to_v2_with_multiple_urls(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 1 to version 2 with different URLs."""
# Create two v1 config entries with different URLs
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
options=TEST_OPTIONS,
version=1,
title="Ollama 1",
)
mock_config_entry.add_to_hass(hass)
mock_config_entry_2 = MockConfigEntry(
domain=DOMAIN,
data={"url": "http://localhost:11435", "model": "llama3.2:latest"},
options=TEST_OPTIONS,
version=1,
title="Ollama 2",
)
mock_config_entry_2.add_to_hass(hass)
device = device_registry.async_get_or_create(
config_entry_id=mock_config_entry.entry_id,
identifiers={(DOMAIN, mock_config_entry.entry_id)},
name=mock_config_entry.title,
manufacturer="Ollama",
model="Ollama 1",
entry_type=dr.DeviceEntryType.SERVICE,
)
entity_registry.async_get_or_create(
"conversation",
DOMAIN,
mock_config_entry.entry_id,
config_entry=mock_config_entry,
device_id=device.id,
suggested_object_id="ollama_1",
)
device_2 = device_registry.async_get_or_create(
config_entry_id=mock_config_entry_2.entry_id,
identifiers={(DOMAIN, mock_config_entry_2.entry_id)},
name=mock_config_entry_2.title,
manufacturer="Ollama",
model="Ollama 2",
entry_type=dr.DeviceEntryType.SERVICE,
)
entity_registry.async_get_or_create(
"conversation",
DOMAIN,
mock_config_entry_2.entry_id,
config_entry=mock_config_entry_2,
device_id=device_2.id,
suggested_object_id="ollama_2",
)
# Run migration
with patch(
"homeassistant.components.ollama.async_setup_entry",
return_value=True,
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 2
for idx, entry in enumerate(entries):
assert entry.version == 2
assert not entry.options
assert len(entry.subentries) == 1
subentry = list(entry.subentries.values())[0]
assert subentry.subentry_type == "conversation"
assert subentry.data == TEST_OPTIONS
assert subentry.title == f"Ollama {idx + 1}"
dev = device_registry.async_get_device(
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
)
assert dev is not None
async def test_migration_from_v1_to_v2_with_same_urls(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 1 to version 2 with same URLs consolidates entries."""
# Create two v1 config entries with the same URL
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
options=TEST_OPTIONS,
version=1,
title="Ollama",
)
mock_config_entry.add_to_hass(hass)
mock_config_entry_2 = MockConfigEntry(
domain=DOMAIN,
data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, # Same URL
options=TEST_OPTIONS,
version=1,
title="Ollama 2",
)
mock_config_entry_2.add_to_hass(hass)
device = device_registry.async_get_or_create(
config_entry_id=mock_config_entry.entry_id,
identifiers={(DOMAIN, mock_config_entry.entry_id)},
name=mock_config_entry.title,
manufacturer="Ollama",
model="Ollama",
entry_type=dr.DeviceEntryType.SERVICE,
)
entity_registry.async_get_or_create(
"conversation",
DOMAIN,
mock_config_entry.entry_id,
config_entry=mock_config_entry,
device_id=device.id,
suggested_object_id="ollama",
)
device_2 = device_registry.async_get_or_create(
config_entry_id=mock_config_entry_2.entry_id,
identifiers={(DOMAIN, mock_config_entry_2.entry_id)},
name=mock_config_entry_2.title,
manufacturer="Ollama",
model="Ollama",
entry_type=dr.DeviceEntryType.SERVICE,
)
entity_registry.async_get_or_create(
"conversation",
DOMAIN,
mock_config_entry_2.entry_id,
config_entry=mock_config_entry_2,
device_id=device_2.id,
suggested_object_id="ollama_2",
)
# Run migration
with patch(
"homeassistant.components.ollama.async_setup_entry",
return_value=True,
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
# Should have only one entry left (consolidated)
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.version == 2
assert not entry.options
assert len(entry.subentries) == 2 # Two subentries from the two original entries
# Check both subentries exist with correct data
subentries = list(entry.subentries.values())
titles = [sub.title for sub in subentries]
assert "Ollama" in titles
assert "Ollama 2" in titles
for subentry in subentries:
assert subentry.subentry_type == "conversation"
assert subentry.data == TEST_OPTIONS
# Check devices were migrated correctly
dev = device_registry.async_get_device(
identifiers={(DOMAIN, subentry.subentry_id)}
)
assert dev is not None