mirror of
https://github.com/home-assistant/core.git
synced 2025-07-10 14:57:09 +00:00
Convert Ollama to subentries (#147286)
* Convert Ollama to subentries * Add latest changes from Google subentries * Move config entry type to init
This commit is contained in:
parent
5a20ef3f3f
commit
f735331699
@ -8,11 +8,16 @@ import logging
|
|||||||
import httpx
|
import httpx
|
||||||
import ollama
|
import ollama
|
||||||
|
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||||
from homeassistant.const import CONF_URL, Platform
|
from homeassistant.const import CONF_URL, Platform
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import ConfigEntryNotReady
|
from homeassistant.exceptions import ConfigEntryNotReady
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import (
|
||||||
|
config_validation as cv,
|
||||||
|
device_registry as dr,
|
||||||
|
entity_registry as er,
|
||||||
|
)
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.util.ssl import get_default_context
|
from homeassistant.util.ssl import get_default_context
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
@ -42,8 +47,16 @@ __all__ = [
|
|||||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||||
PLATFORMS = (Platform.CONVERSATION,)
|
PLATFORMS = (Platform.CONVERSATION,)
|
||||||
|
|
||||||
|
type OllamaConfigEntry = ConfigEntry[ollama.AsyncClient]
|
||||||
|
|
||||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|
||||||
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
|
"""Set up Ollama."""
|
||||||
|
await async_migrate_integration(hass)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(hass: HomeAssistant, entry: OllamaConfigEntry) -> bool:
|
||||||
"""Set up Ollama from a config entry."""
|
"""Set up Ollama from a config entry."""
|
||||||
settings = {**entry.data, **entry.options}
|
settings = {**entry.data, **entry.options}
|
||||||
client = ollama.AsyncClient(host=settings[CONF_URL], verify=get_default_context())
|
client = ollama.AsyncClient(host=settings[CONF_URL], verify=get_default_context())
|
||||||
@ -53,8 +66,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
except (TimeoutError, httpx.ConnectError) as err:
|
except (TimeoutError, httpx.ConnectError) as err:
|
||||||
raise ConfigEntryNotReady(err) from err
|
raise ConfigEntryNotReady(err) from err
|
||||||
|
|
||||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
|
entry.runtime_data = client
|
||||||
|
|
||||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -63,5 +75,69 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
"""Unload Ollama."""
|
"""Unload Ollama."""
|
||||||
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
|
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
|
||||||
return False
|
return False
|
||||||
hass.data[DOMAIN].pop(entry.entry_id)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def async_migrate_integration(hass: HomeAssistant) -> None:
|
||||||
|
"""Migrate integration entry structure."""
|
||||||
|
|
||||||
|
entries = hass.config_entries.async_entries(DOMAIN)
|
||||||
|
if not any(entry.version == 1 for entry in entries):
|
||||||
|
return
|
||||||
|
|
||||||
|
api_keys_entries: dict[str, ConfigEntry] = {}
|
||||||
|
entity_registry = er.async_get(hass)
|
||||||
|
device_registry = dr.async_get(hass)
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
use_existing = False
|
||||||
|
subentry = ConfigSubentry(
|
||||||
|
data=entry.options,
|
||||||
|
subentry_type="conversation",
|
||||||
|
title=entry.title,
|
||||||
|
unique_id=None,
|
||||||
|
)
|
||||||
|
if entry.data[CONF_URL] not in api_keys_entries:
|
||||||
|
use_existing = True
|
||||||
|
api_keys_entries[entry.data[CONF_URL]] = entry
|
||||||
|
|
||||||
|
parent_entry = api_keys_entries[entry.data[CONF_URL]]
|
||||||
|
|
||||||
|
hass.config_entries.async_add_subentry(parent_entry, subentry)
|
||||||
|
conversation_entity = entity_registry.async_get_entity_id(
|
||||||
|
"conversation",
|
||||||
|
DOMAIN,
|
||||||
|
entry.entry_id,
|
||||||
|
)
|
||||||
|
if conversation_entity is not None:
|
||||||
|
entity_registry.async_update_entity(
|
||||||
|
conversation_entity,
|
||||||
|
config_entry_id=parent_entry.entry_id,
|
||||||
|
config_subentry_id=subentry.subentry_id,
|
||||||
|
new_unique_id=subentry.subentry_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
device = device_registry.async_get_device(
|
||||||
|
identifiers={(DOMAIN, entry.entry_id)}
|
||||||
|
)
|
||||||
|
if device is not None:
|
||||||
|
device_registry.async_update_device(
|
||||||
|
device.id,
|
||||||
|
new_identifiers={(DOMAIN, subentry.subentry_id)},
|
||||||
|
add_config_subentry_id=subentry.subentry_id,
|
||||||
|
add_config_entry_id=parent_entry.entry_id,
|
||||||
|
)
|
||||||
|
if parent_entry.entry_id != entry.entry_id:
|
||||||
|
device_registry.async_update_device(
|
||||||
|
device.id,
|
||||||
|
remove_config_entry_id=entry.entry_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not use_existing:
|
||||||
|
await hass.config_entries.async_remove(entry.entry_id)
|
||||||
|
else:
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
entry,
|
||||||
|
options={},
|
||||||
|
version=2,
|
||||||
|
)
|
||||||
|
@ -14,12 +14,14 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant.config_entries import (
|
from homeassistant.config_entries import (
|
||||||
ConfigEntry,
|
ConfigEntry,
|
||||||
|
ConfigEntryState,
|
||||||
ConfigFlow,
|
ConfigFlow,
|
||||||
ConfigFlowResult,
|
ConfigFlowResult,
|
||||||
OptionsFlow,
|
ConfigSubentryFlow,
|
||||||
|
SubentryFlowResult,
|
||||||
)
|
)
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_URL
|
from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import llm
|
from homeassistant.helpers import llm
|
||||||
from homeassistant.helpers.selector import (
|
from homeassistant.helpers.selector import (
|
||||||
BooleanSelector,
|
BooleanSelector,
|
||||||
@ -70,7 +72,7 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
|||||||
class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||||
"""Handle a config flow for Ollama."""
|
"""Handle a config flow for Ollama."""
|
||||||
|
|
||||||
VERSION = 1
|
VERSION = 2
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize config flow."""
|
"""Initialize config flow."""
|
||||||
@ -94,6 +96,8 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
|
|
||||||
errors = {}
|
errors = {}
|
||||||
|
|
||||||
|
self._async_abort_entries_match({CONF_URL: self.url})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.client = ollama.AsyncClient(
|
self.client = ollama.AsyncClient(
|
||||||
host=self.url, verify=get_default_context()
|
host=self.url, verify=get_default_context()
|
||||||
@ -146,8 +150,16 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
return await self.async_step_download()
|
return await self.async_step_download()
|
||||||
|
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=_get_title(self.model),
|
title=self.url,
|
||||||
data={CONF_URL: self.url, CONF_MODEL: self.model},
|
data={CONF_URL: self.url, CONF_MODEL: self.model},
|
||||||
|
subentries=[
|
||||||
|
{
|
||||||
|
"subentry_type": "conversation",
|
||||||
|
"data": {},
|
||||||
|
"title": _get_title(self.model),
|
||||||
|
"unique_id": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_step_download(
|
async def async_step_download(
|
||||||
@ -189,6 +201,14 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=_get_title(self.model),
|
title=_get_title(self.model),
|
||||||
data={CONF_URL: self.url, CONF_MODEL: self.model},
|
data={CONF_URL: self.url, CONF_MODEL: self.model},
|
||||||
|
subentries=[
|
||||||
|
{
|
||||||
|
"subentry_type": "conversation",
|
||||||
|
"data": {},
|
||||||
|
"title": _get_title(self.model),
|
||||||
|
"unique_id": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_step_failed(
|
async def async_step_failed(
|
||||||
@ -197,41 +217,62 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
"""Step after model downloading has failed."""
|
"""Step after model downloading has failed."""
|
||||||
return self.async_abort(reason="download_failed")
|
return self.async_abort(reason="download_failed")
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def async_get_options_flow(
|
@callback
|
||||||
config_entry: ConfigEntry,
|
def async_get_supported_subentry_types(
|
||||||
) -> OptionsFlow:
|
cls, config_entry: ConfigEntry
|
||||||
"""Create the options flow."""
|
) -> dict[str, type[ConfigSubentryFlow]]:
|
||||||
return OllamaOptionsFlow(config_entry)
|
"""Return subentries supported by this integration."""
|
||||||
|
return {"conversation": ConversationSubentryFlowHandler}
|
||||||
|
|
||||||
|
|
||||||
class OllamaOptionsFlow(OptionsFlow):
|
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||||
"""Ollama options flow."""
|
"""Flow for managing conversation subentries."""
|
||||||
|
|
||||||
def __init__(self, config_entry: ConfigEntry) -> None:
|
@property
|
||||||
"""Initialize options flow."""
|
def _is_new(self) -> bool:
|
||||||
self.url: str = config_entry.data[CONF_URL]
|
"""Return if this is a new subentry."""
|
||||||
self.model: str = config_entry.data[CONF_MODEL]
|
return self.source == "user"
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_set_options(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> ConfigFlowResult:
|
) -> SubentryFlowResult:
|
||||||
"""Manage the options."""
|
"""Set conversation options."""
|
||||||
if user_input is not None:
|
# abort if entry is not loaded
|
||||||
|
if self._get_entry().state != ConfigEntryState.LOADED:
|
||||||
|
return self.async_abort(reason="entry_not_loaded")
|
||||||
|
|
||||||
|
errors: dict[str, str] = {}
|
||||||
|
|
||||||
|
if user_input is None:
|
||||||
|
if self._is_new:
|
||||||
|
options = {}
|
||||||
|
else:
|
||||||
|
options = self._get_reconfigure_subentry().data.copy()
|
||||||
|
|
||||||
|
elif self._is_new:
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=_get_title(self.model), data=user_input
|
title=user_input.pop(CONF_NAME),
|
||||||
|
data=user_input,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.async_update_and_abort(
|
||||||
|
self._get_entry(),
|
||||||
|
self._get_reconfigure_subentry(),
|
||||||
|
data=user_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
options: Mapping[str, Any] = self.config_entry.options or {}
|
schema = ollama_config_option_schema(self.hass, self._is_new, options)
|
||||||
schema = ollama_config_option_schema(self.hass, options)
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="init",
|
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
|
||||||
data_schema=vol.Schema(schema),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async_step_user = async_step_set_options
|
||||||
|
async_step_reconfigure = async_step_set_options
|
||||||
|
|
||||||
|
|
||||||
def ollama_config_option_schema(
|
def ollama_config_option_schema(
|
||||||
hass: HomeAssistant, options: Mapping[str, Any]
|
hass: HomeAssistant, is_new: bool, options: Mapping[str, Any]
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Ollama options schema."""
|
"""Ollama options schema."""
|
||||||
hass_apis: list[SelectOptionDict] = [
|
hass_apis: list[SelectOptionDict] = [
|
||||||
@ -242,54 +283,72 @@ def ollama_config_option_schema(
|
|||||||
for api in llm.async_get_apis(hass)
|
for api in llm.async_get_apis(hass)
|
||||||
]
|
]
|
||||||
|
|
||||||
return {
|
if is_new:
|
||||||
vol.Optional(
|
schema: dict[vol.Required | vol.Optional, Any] = {
|
||||||
CONF_PROMPT,
|
vol.Required(CONF_NAME, default="Ollama Conversation"): str,
|
||||||
description={
|
}
|
||||||
"suggested_value": options.get(
|
else:
|
||||||
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
|
schema = {}
|
||||||
|
|
||||||
|
schema.update(
|
||||||
|
{
|
||||||
|
vol.Optional(
|
||||||
|
CONF_PROMPT,
|
||||||
|
description={
|
||||||
|
"suggested_value": options.get(
|
||||||
|
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
|
||||||
|
)
|
||||||
|
},
|
||||||
|
): TemplateSelector(),
|
||||||
|
vol.Optional(
|
||||||
|
CONF_LLM_HASS_API,
|
||||||
|
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||||
|
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
|
||||||
|
vol.Optional(
|
||||||
|
CONF_NUM_CTX,
|
||||||
|
description={
|
||||||
|
"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)
|
||||||
|
},
|
||||||
|
): NumberSelector(
|
||||||
|
NumberSelectorConfig(
|
||||||
|
min=MIN_NUM_CTX,
|
||||||
|
max=MAX_NUM_CTX,
|
||||||
|
step=1,
|
||||||
|
mode=NumberSelectorMode.BOX,
|
||||||
)
|
)
|
||||||
},
|
),
|
||||||
): TemplateSelector(),
|
vol.Optional(
|
||||||
vol.Optional(
|
CONF_MAX_HISTORY,
|
||||||
CONF_LLM_HASS_API,
|
description={
|
||||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
"suggested_value": options.get(
|
||||||
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
|
CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY
|
||||||
vol.Optional(
|
)
|
||||||
CONF_NUM_CTX,
|
},
|
||||||
description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
|
): NumberSelector(
|
||||||
): NumberSelector(
|
NumberSelectorConfig(
|
||||||
NumberSelectorConfig(
|
min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
|
||||||
min=MIN_NUM_CTX, max=MAX_NUM_CTX, step=1, mode=NumberSelectorMode.BOX
|
)
|
||||||
)
|
),
|
||||||
),
|
vol.Optional(
|
||||||
vol.Optional(
|
CONF_KEEP_ALIVE,
|
||||||
CONF_MAX_HISTORY,
|
description={
|
||||||
description={
|
"suggested_value": options.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)
|
||||||
"suggested_value": options.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY)
|
},
|
||||||
},
|
): NumberSelector(
|
||||||
): NumberSelector(
|
NumberSelectorConfig(
|
||||||
NumberSelectorConfig(
|
min=-1, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
|
||||||
min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
|
)
|
||||||
)
|
),
|
||||||
),
|
vol.Optional(
|
||||||
vol.Optional(
|
CONF_THINK,
|
||||||
CONF_KEEP_ALIVE,
|
description={
|
||||||
description={
|
"suggested_value": options.get("think", DEFAULT_THINK),
|
||||||
"suggested_value": options.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)
|
},
|
||||||
},
|
): BooleanSelector(),
|
||||||
): NumberSelector(
|
}
|
||||||
NumberSelectorConfig(
|
)
|
||||||
min=-1, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
|
|
||||||
)
|
return schema
|
||||||
),
|
|
||||||
vol.Optional(
|
|
||||||
CONF_THINK,
|
|
||||||
description={
|
|
||||||
"suggested_value": options.get("think", DEFAULT_THINK),
|
|
||||||
},
|
|
||||||
): BooleanSelector(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_title(model: str) -> str:
|
def _get_title(model: str) -> str:
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, AsyncIterator, Callable
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
@ -11,13 +11,14 @@ import ollama
|
|||||||
from voluptuous_openapi import convert
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, conversation
|
from homeassistant.components import assist_pipeline, conversation
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import intent, llm
|
from homeassistant.helpers import device_registry as dr, intent, llm
|
||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
|
||||||
|
from . import OllamaConfigEntry
|
||||||
from .const import (
|
from .const import (
|
||||||
CONF_KEEP_ALIVE,
|
CONF_KEEP_ALIVE,
|
||||||
CONF_MAX_HISTORY,
|
CONF_MAX_HISTORY,
|
||||||
@ -40,12 +41,18 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
|
|
||||||
async def async_setup_entry(
|
async def async_setup_entry(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
config_entry: ConfigEntry,
|
config_entry: OllamaConfigEntry,
|
||||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up conversation entities."""
|
"""Set up conversation entities."""
|
||||||
agent = OllamaConversationEntity(config_entry)
|
for subentry in config_entry.subentries.values():
|
||||||
async_add_entities([agent])
|
if subentry.subentry_type != "conversation":
|
||||||
|
continue
|
||||||
|
|
||||||
|
async_add_entities(
|
||||||
|
[OllamaConversationEntity(config_entry, subentry)],
|
||||||
|
config_subentry_id=subentry.subentry_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _format_tool(
|
def _format_tool(
|
||||||
@ -130,7 +137,7 @@ def _convert_content(
|
|||||||
|
|
||||||
|
|
||||||
async def _transform_stream(
|
async def _transform_stream(
|
||||||
result: AsyncGenerator[ollama.Message],
|
result: AsyncIterator[ollama.ChatResponse],
|
||||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||||
"""Transform the response stream into HA format.
|
"""Transform the response stream into HA format.
|
||||||
|
|
||||||
@ -174,17 +181,22 @@ class OllamaConversationEntity(
|
|||||||
):
|
):
|
||||||
"""Ollama conversation agent."""
|
"""Ollama conversation agent."""
|
||||||
|
|
||||||
_attr_has_entity_name = True
|
|
||||||
_attr_supports_streaming = True
|
_attr_supports_streaming = True
|
||||||
|
|
||||||
def __init__(self, entry: ConfigEntry) -> None:
|
def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None:
|
||||||
"""Initialize the agent."""
|
"""Initialize the agent."""
|
||||||
self.entry = entry
|
self.entry = entry
|
||||||
|
self.subentry = subentry
|
||||||
# conversation id -> message history
|
self._attr_name = subentry.title
|
||||||
self._attr_name = entry.title
|
self._attr_unique_id = subentry.subentry_id
|
||||||
self._attr_unique_id = entry.entry_id
|
self._attr_device_info = dr.DeviceInfo(
|
||||||
if self.entry.options.get(CONF_LLM_HASS_API):
|
identifiers={(DOMAIN, subentry.subentry_id)},
|
||||||
|
name=subentry.title,
|
||||||
|
manufacturer="Ollama",
|
||||||
|
model=entry.data[CONF_MODEL],
|
||||||
|
entry_type=dr.DeviceEntryType.SERVICE,
|
||||||
|
)
|
||||||
|
if self.subentry.data.get(CONF_LLM_HASS_API):
|
||||||
self._attr_supported_features = (
|
self._attr_supported_features = (
|
||||||
conversation.ConversationEntityFeature.CONTROL
|
conversation.ConversationEntityFeature.CONTROL
|
||||||
)
|
)
|
||||||
@ -216,7 +228,7 @@ class OllamaConversationEntity(
|
|||||||
chat_log: conversation.ChatLog,
|
chat_log: conversation.ChatLog,
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Call the API."""
|
"""Call the API."""
|
||||||
settings = {**self.entry.data, **self.entry.options}
|
settings = {**self.entry.data, **self.subentry.data}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await chat_log.async_provide_llm_data(
|
await chat_log.async_provide_llm_data(
|
||||||
@ -248,9 +260,9 @@ class OllamaConversationEntity(
|
|||||||
chat_log: conversation.ChatLog,
|
chat_log: conversation.ChatLog,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Generate an answer for the chat log."""
|
"""Generate an answer for the chat log."""
|
||||||
settings = {**self.entry.data, **self.entry.options}
|
settings = {**self.entry.data, **self.subentry.data}
|
||||||
|
|
||||||
client = self.hass.data[DOMAIN][self.entry.entry_id]
|
client = self.entry.runtime_data
|
||||||
model = settings[CONF_MODEL]
|
model = settings[CONF_MODEL]
|
||||||
|
|
||||||
tools: list[dict[str, Any]] | None = None
|
tools: list[dict[str, Any]] | None = None
|
||||||
|
@ -12,7 +12,8 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"abort": {
|
"abort": {
|
||||||
"download_failed": "Model downloading failed"
|
"download_failed": "Model downloading failed",
|
||||||
|
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]"
|
||||||
},
|
},
|
||||||
"error": {
|
"error": {
|
||||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||||
@ -22,23 +23,35 @@
|
|||||||
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
|
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"options": {
|
"config_subentries": {
|
||||||
"step": {
|
"conversation": {
|
||||||
"init": {
|
"initiate_flow": {
|
||||||
"data": {
|
"user": "Add conversation agent",
|
||||||
"prompt": "Instructions",
|
"reconfigure": "Reconfigure conversation agent"
|
||||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
|
},
|
||||||
"max_history": "Max history messages",
|
"entry_type": "Conversation agent",
|
||||||
"num_ctx": "Context window size",
|
"step": {
|
||||||
"keep_alive": "Keep alive",
|
"set_options": {
|
||||||
"think": "Think before responding"
|
"data": {
|
||||||
},
|
"name": "[%key:common::config_flow::data::name%]",
|
||||||
"data_description": {
|
"prompt": "Instructions",
|
||||||
"prompt": "Instruct how the LLM should respond. This can be a template.",
|
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
|
||||||
"keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never.",
|
"max_history": "Max history messages",
|
||||||
"num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.",
|
"num_ctx": "Context window size",
|
||||||
"think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency."
|
"keep_alive": "Keep alive",
|
||||||
|
"think": "Think before responding"
|
||||||
|
},
|
||||||
|
"data_description": {
|
||||||
|
"prompt": "Instruct how the LLM should respond. This can be a template.",
|
||||||
|
"keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never.",
|
||||||
|
"num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.",
|
||||||
|
"think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency."
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"abort": {
|
||||||
|
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
|
||||||
|
"entry_not_loaded": "Cannot add things while the configuration is disabled."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,15 @@ def mock_config_entry(
|
|||||||
entry = MockConfigEntry(
|
entry = MockConfigEntry(
|
||||||
domain=ollama.DOMAIN,
|
domain=ollama.DOMAIN,
|
||||||
data=TEST_USER_DATA,
|
data=TEST_USER_DATA,
|
||||||
options=mock_config_entry_options,
|
version=2,
|
||||||
|
subentries_data=[
|
||||||
|
{
|
||||||
|
"data": mock_config_entry_options,
|
||||||
|
"subentry_type": "conversation",
|
||||||
|
"title": "Ollama Conversation",
|
||||||
|
"unique_id": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
)
|
)
|
||||||
entry.add_to_hass(hass)
|
entry.add_to_hass(hass)
|
||||||
return entry
|
return entry
|
||||||
@ -41,8 +49,10 @@ def mock_config_entry_with_assist(
|
|||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||||
) -> MockConfigEntry:
|
) -> MockConfigEntry:
|
||||||
"""Mock a config entry with assist."""
|
"""Mock a config entry with assist."""
|
||||||
hass.config_entries.async_update_entry(
|
hass.config_entries.async_update_subentry(
|
||||||
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
|
mock_config_entry,
|
||||||
|
next(iter(mock_config_entry.subentries.values())),
|
||||||
|
data={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
|
||||||
)
|
)
|
||||||
return mock_config_entry
|
return mock_config_entry
|
||||||
|
|
||||||
|
@ -63,6 +63,37 @@ async def test_form(hass: HomeAssistant) -> None:
|
|||||||
assert len(mock_setup_entry.mock_calls) == 1
|
assert len(mock_setup_entry.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_duplicate_entry(hass: HomeAssistant) -> None:
|
||||||
|
"""Test we abort on duplicate config entry."""
|
||||||
|
MockConfigEntry(
|
||||||
|
domain=ollama.DOMAIN,
|
||||||
|
data={
|
||||||
|
ollama.CONF_URL: "http://localhost:11434",
|
||||||
|
ollama.CONF_MODEL: "test_model",
|
||||||
|
},
|
||||||
|
).add_to_hass(hass)
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
assert result["type"] is FlowResultType.FORM
|
||||||
|
assert not result["errors"]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
|
||||||
|
return_value={"models": [{"model": "test_model"}]},
|
||||||
|
):
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{
|
||||||
|
ollama.CONF_URL: "http://localhost:11434",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "already_configured"
|
||||||
|
|
||||||
|
|
||||||
async def test_form_need_download(hass: HomeAssistant) -> None:
|
async def test_form_need_download(hass: HomeAssistant) -> None:
|
||||||
"""Test flow when a model needs to be downloaded."""
|
"""Test flow when a model needs to be downloaded."""
|
||||||
# Pretend we already set up a config entry.
|
# Pretend we already set up a config entry.
|
||||||
@ -155,14 +186,21 @@ async def test_form_need_download(hass: HomeAssistant) -> None:
|
|||||||
assert len(mock_setup_entry.mock_calls) == 1
|
assert len(mock_setup_entry.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_options(
|
async def test_subentry_options(
|
||||||
hass: HomeAssistant, mock_config_entry, mock_init_component
|
hass: HomeAssistant, mock_config_entry, mock_init_component
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the options form."""
|
"""Test the subentry options form."""
|
||||||
options_flow = await hass.config_entries.options.async_init(
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
mock_config_entry.entry_id
|
|
||||||
|
# Test reconfiguration
|
||||||
|
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
|
||||||
|
hass, subentry.subentry_id
|
||||||
)
|
)
|
||||||
options = await hass.config_entries.options.async_configure(
|
|
||||||
|
assert options_flow["type"] is FlowResultType.FORM
|
||||||
|
assert options_flow["step_id"] == "set_options"
|
||||||
|
|
||||||
|
options = await hass.config_entries.subentries.async_configure(
|
||||||
options_flow["flow_id"],
|
options_flow["flow_id"],
|
||||||
{
|
{
|
||||||
ollama.CONF_PROMPT: "test prompt",
|
ollama.CONF_PROMPT: "test prompt",
|
||||||
@ -172,8 +210,10 @@ async def test_options(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert options["type"] is FlowResultType.CREATE_ENTRY
|
|
||||||
assert options["data"] == {
|
assert options["type"] is FlowResultType.ABORT
|
||||||
|
assert options["reason"] == "reconfigure_successful"
|
||||||
|
assert subentry.data == {
|
||||||
ollama.CONF_PROMPT: "test prompt",
|
ollama.CONF_PROMPT: "test prompt",
|
||||||
ollama.CONF_MAX_HISTORY: 100,
|
ollama.CONF_MAX_HISTORY: 100,
|
||||||
ollama.CONF_NUM_CTX: 32768,
|
ollama.CONF_NUM_CTX: 32768,
|
||||||
@ -181,6 +221,22 @@ async def test_options(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_creating_conversation_subentry_not_loaded(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_init_component,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating a conversation subentry when entry is not loaded."""
|
||||||
|
await hass.config_entries.async_unload(mock_config_entry.entry_id)
|
||||||
|
result = await hass.config_entries.subentries.async_init(
|
||||||
|
(mock_config_entry.entry_id, "conversation"),
|
||||||
|
context={"source": config_entries.SOURCE_USER},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "entry_not_loaded"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("side_effect", "error"),
|
("side_effect", "error"),
|
||||||
[
|
[
|
||||||
|
@ -35,7 +35,7 @@ async def stream_generator(response: dict | list[dict]) -> AsyncGenerator[dict]:
|
|||||||
yield msg
|
yield msg
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"])
|
@pytest.mark.parametrize("agent_id", [None, "conversation.ollama_conversation"])
|
||||||
async def test_chat(
|
async def test_chat(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
@ -149,9 +149,11 @@ async def test_template_variables(
|
|||||||
mock_user.id = "12345"
|
mock_user.id = "12345"
|
||||||
mock_user.name = "Test User"
|
mock_user.name = "Test User"
|
||||||
|
|
||||||
hass.config_entries.async_update_entry(
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
|
hass.config_entries.async_update_subentry(
|
||||||
mock_config_entry,
|
mock_config_entry,
|
||||||
options={
|
subentry,
|
||||||
|
data={
|
||||||
"prompt": (
|
"prompt": (
|
||||||
"The user name is {{ user_name }}. "
|
"The user name is {{ user_name }}. "
|
||||||
"The user id is {{ llm_context.context.user_id }}."
|
"The user id is {{ llm_context.context.user_id }}."
|
||||||
@ -382,10 +384,12 @@ async def test_unknown_hass_api(
|
|||||||
mock_init_component,
|
mock_init_component,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test when we reference an API that no longer exists."""
|
"""Test when we reference an API that no longer exists."""
|
||||||
hass.config_entries.async_update_entry(
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
|
hass.config_entries.async_update_subentry(
|
||||||
mock_config_entry,
|
mock_config_entry,
|
||||||
options={
|
subentry,
|
||||||
**mock_config_entry.options,
|
data={
|
||||||
|
**subentry.data,
|
||||||
CONF_LLM_HASS_API: "non-existing",
|
CONF_LLM_HASS_API: "non-existing",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -518,8 +522,9 @@ async def test_message_history_unlimited(
|
|||||||
with (
|
with (
|
||||||
patch("ollama.AsyncClient.chat", side_effect=stream) as mock_chat,
|
patch("ollama.AsyncClient.chat", side_effect=stream) as mock_chat,
|
||||||
):
|
):
|
||||||
hass.config_entries.async_update_entry(
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0}
|
hass.config_entries.async_update_subentry(
|
||||||
|
mock_config_entry, subentry, data={ollama.CONF_MAX_HISTORY: 0}
|
||||||
)
|
)
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
result = await conversation.async_converse(
|
result = await conversation.async_converse(
|
||||||
@ -563,9 +568,11 @@ async def test_template_error(
|
|||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that template error handling works."""
|
"""Test that template error handling works."""
|
||||||
hass.config_entries.async_update_entry(
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
|
hass.config_entries.async_update_subentry(
|
||||||
mock_config_entry,
|
mock_config_entry,
|
||||||
options={
|
subentry,
|
||||||
|
data={
|
||||||
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
|
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -593,7 +600,7 @@ async def test_conversation_agent(
|
|||||||
)
|
)
|
||||||
assert agent.supported_languages == MATCH_ALL
|
assert agent.supported_languages == MATCH_ALL
|
||||||
|
|
||||||
state = hass.states.get("conversation.mock_title")
|
state = hass.states.get("conversation.ollama_conversation")
|
||||||
assert state
|
assert state
|
||||||
assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0
|
assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0
|
||||||
|
|
||||||
@ -609,7 +616,7 @@ async def test_conversation_agent_with_assist(
|
|||||||
)
|
)
|
||||||
assert agent.supported_languages == MATCH_ALL
|
assert agent.supported_languages == MATCH_ALL
|
||||||
|
|
||||||
state = hass.states.get("conversation.mock_title")
|
state = hass.states.get("conversation.ollama_conversation")
|
||||||
assert state
|
assert state
|
||||||
assert (
|
assert (
|
||||||
state.attributes[ATTR_SUPPORTED_FEATURES]
|
state.attributes[ATTR_SUPPORTED_FEATURES]
|
||||||
@ -642,7 +649,7 @@ async def test_options(
|
|||||||
"test message",
|
"test message",
|
||||||
None,
|
None,
|
||||||
Context(),
|
Context(),
|
||||||
agent_id="conversation.mock_title",
|
agent_id="conversation.ollama_conversation",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert mock_chat.call_count == 1
|
assert mock_chat.call_count == 1
|
||||||
@ -667,9 +674,11 @@ async def test_reasoning_filter(
|
|||||||
entry = MockConfigEntry()
|
entry = MockConfigEntry()
|
||||||
entry.add_to_hass(hass)
|
entry.add_to_hass(hass)
|
||||||
|
|
||||||
hass.config_entries.async_update_entry(
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
|
hass.config_entries.async_update_subentry(
|
||||||
mock_config_entry,
|
mock_config_entry,
|
||||||
options={
|
subentry,
|
||||||
|
data={
|
||||||
ollama.CONF_THINK: think,
|
ollama.CONF_THINK: think,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -6,9 +6,13 @@ from httpx import ConnectError
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import ollama
|
from homeassistant.components import ollama
|
||||||
|
from homeassistant.components.ollama.const import DOMAIN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from . import TEST_OPTIONS, TEST_USER_DATA
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
|
||||||
@ -34,3 +38,250 @@ async def test_init_error(
|
|||||||
assert await async_setup_component(hass, ollama.DOMAIN, {})
|
assert await async_setup_component(hass, ollama.DOMAIN, {})
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert error in caplog.text
|
assert error in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_migration_from_v1_to_v2(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test migration from version 1 to version 2."""
|
||||||
|
# Create a v1 config entry with conversation options and an entity
|
||||||
|
mock_config_entry = MockConfigEntry(
|
||||||
|
domain=DOMAIN,
|
||||||
|
data=TEST_USER_DATA,
|
||||||
|
options=TEST_OPTIONS,
|
||||||
|
version=1,
|
||||||
|
title="llama-3.2-8b",
|
||||||
|
)
|
||||||
|
mock_config_entry.add_to_hass(hass)
|
||||||
|
|
||||||
|
device = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=mock_config_entry.entry_id,
|
||||||
|
identifiers={(DOMAIN, mock_config_entry.entry_id)},
|
||||||
|
name=mock_config_entry.title,
|
||||||
|
manufacturer="Ollama",
|
||||||
|
model="Ollama",
|
||||||
|
entry_type=dr.DeviceEntryType.SERVICE,
|
||||||
|
)
|
||||||
|
entity = entity_registry.async_get_or_create(
|
||||||
|
"conversation",
|
||||||
|
DOMAIN,
|
||||||
|
mock_config_entry.entry_id,
|
||||||
|
config_entry=mock_config_entry,
|
||||||
|
device_id=device.id,
|
||||||
|
suggested_object_id="llama_3_2_8b",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run migration
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.ollama.async_setup_entry",
|
||||||
|
return_value=True,
|
||||||
|
):
|
||||||
|
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||||
|
|
||||||
|
assert mock_config_entry.version == 2
|
||||||
|
assert mock_config_entry.data == TEST_USER_DATA
|
||||||
|
assert mock_config_entry.options == {}
|
||||||
|
|
||||||
|
assert len(mock_config_entry.subentries) == 1
|
||||||
|
|
||||||
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
|
assert subentry.unique_id is None
|
||||||
|
assert subentry.title == "llama-3.2-8b"
|
||||||
|
assert subentry.subentry_type == "conversation"
|
||||||
|
assert subentry.data == TEST_OPTIONS
|
||||||
|
|
||||||
|
migrated_entity = entity_registry.async_get(entity.entity_id)
|
||||||
|
assert migrated_entity is not None
|
||||||
|
assert migrated_entity.config_entry_id == mock_config_entry.entry_id
|
||||||
|
assert migrated_entity.config_subentry_id == subentry.subentry_id
|
||||||
|
assert migrated_entity.unique_id == subentry.subentry_id
|
||||||
|
|
||||||
|
# Check device migration
|
||||||
|
assert not device_registry.async_get_device(
|
||||||
|
identifiers={(DOMAIN, mock_config_entry.entry_id)}
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
migrated_device := device_registry.async_get_device(
|
||||||
|
identifiers={(DOMAIN, subentry.subentry_id)}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert migrated_device.identifiers == {(DOMAIN, subentry.subentry_id)}
|
||||||
|
assert migrated_device.id == device.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_migration_from_v1_to_v2_with_multiple_urls(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test migration from version 1 to version 2 with different URLs."""
|
||||||
|
# Create two v1 config entries with different URLs
|
||||||
|
mock_config_entry = MockConfigEntry(
|
||||||
|
domain=DOMAIN,
|
||||||
|
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
|
||||||
|
options=TEST_OPTIONS,
|
||||||
|
version=1,
|
||||||
|
title="Ollama 1",
|
||||||
|
)
|
||||||
|
mock_config_entry.add_to_hass(hass)
|
||||||
|
mock_config_entry_2 = MockConfigEntry(
|
||||||
|
domain=DOMAIN,
|
||||||
|
data={"url": "http://localhost:11435", "model": "llama3.2:latest"},
|
||||||
|
options=TEST_OPTIONS,
|
||||||
|
version=1,
|
||||||
|
title="Ollama 2",
|
||||||
|
)
|
||||||
|
mock_config_entry_2.add_to_hass(hass)
|
||||||
|
|
||||||
|
device = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=mock_config_entry.entry_id,
|
||||||
|
identifiers={(DOMAIN, mock_config_entry.entry_id)},
|
||||||
|
name=mock_config_entry.title,
|
||||||
|
manufacturer="Ollama",
|
||||||
|
model="Ollama 1",
|
||||||
|
entry_type=dr.DeviceEntryType.SERVICE,
|
||||||
|
)
|
||||||
|
entity_registry.async_get_or_create(
|
||||||
|
"conversation",
|
||||||
|
DOMAIN,
|
||||||
|
mock_config_entry.entry_id,
|
||||||
|
config_entry=mock_config_entry,
|
||||||
|
device_id=device.id,
|
||||||
|
suggested_object_id="ollama_1",
|
||||||
|
)
|
||||||
|
|
||||||
|
device_2 = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=mock_config_entry_2.entry_id,
|
||||||
|
identifiers={(DOMAIN, mock_config_entry_2.entry_id)},
|
||||||
|
name=mock_config_entry_2.title,
|
||||||
|
manufacturer="Ollama",
|
||||||
|
model="Ollama 2",
|
||||||
|
entry_type=dr.DeviceEntryType.SERVICE,
|
||||||
|
)
|
||||||
|
entity_registry.async_get_or_create(
|
||||||
|
"conversation",
|
||||||
|
DOMAIN,
|
||||||
|
mock_config_entry_2.entry_id,
|
||||||
|
config_entry=mock_config_entry_2,
|
||||||
|
device_id=device_2.id,
|
||||||
|
suggested_object_id="ollama_2",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run migration
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.ollama.async_setup_entry",
|
||||||
|
return_value=True,
|
||||||
|
):
|
||||||
|
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
entries = hass.config_entries.async_entries(DOMAIN)
|
||||||
|
assert len(entries) == 2
|
||||||
|
|
||||||
|
for idx, entry in enumerate(entries):
|
||||||
|
assert entry.version == 2
|
||||||
|
assert not entry.options
|
||||||
|
assert len(entry.subentries) == 1
|
||||||
|
subentry = list(entry.subentries.values())[0]
|
||||||
|
assert subentry.subentry_type == "conversation"
|
||||||
|
assert subentry.data == TEST_OPTIONS
|
||||||
|
assert subentry.title == f"Ollama {idx + 1}"
|
||||||
|
|
||||||
|
dev = device_registry.async_get_device(
|
||||||
|
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
|
||||||
|
)
|
||||||
|
assert dev is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_migration_from_v1_to_v2_with_same_urls(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test migration from version 1 to version 2 with same URLs consolidates entries."""
|
||||||
|
# Create two v1 config entries with the same URL
|
||||||
|
mock_config_entry = MockConfigEntry(
|
||||||
|
domain=DOMAIN,
|
||||||
|
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
|
||||||
|
options=TEST_OPTIONS,
|
||||||
|
version=1,
|
||||||
|
title="Ollama",
|
||||||
|
)
|
||||||
|
mock_config_entry.add_to_hass(hass)
|
||||||
|
mock_config_entry_2 = MockConfigEntry(
|
||||||
|
domain=DOMAIN,
|
||||||
|
data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, # Same URL
|
||||||
|
options=TEST_OPTIONS,
|
||||||
|
version=1,
|
||||||
|
title="Ollama 2",
|
||||||
|
)
|
||||||
|
mock_config_entry_2.add_to_hass(hass)
|
||||||
|
|
||||||
|
device = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=mock_config_entry.entry_id,
|
||||||
|
identifiers={(DOMAIN, mock_config_entry.entry_id)},
|
||||||
|
name=mock_config_entry.title,
|
||||||
|
manufacturer="Ollama",
|
||||||
|
model="Ollama",
|
||||||
|
entry_type=dr.DeviceEntryType.SERVICE,
|
||||||
|
)
|
||||||
|
entity_registry.async_get_or_create(
|
||||||
|
"conversation",
|
||||||
|
DOMAIN,
|
||||||
|
mock_config_entry.entry_id,
|
||||||
|
config_entry=mock_config_entry,
|
||||||
|
device_id=device.id,
|
||||||
|
suggested_object_id="ollama",
|
||||||
|
)
|
||||||
|
|
||||||
|
device_2 = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=mock_config_entry_2.entry_id,
|
||||||
|
identifiers={(DOMAIN, mock_config_entry_2.entry_id)},
|
||||||
|
name=mock_config_entry_2.title,
|
||||||
|
manufacturer="Ollama",
|
||||||
|
model="Ollama",
|
||||||
|
entry_type=dr.DeviceEntryType.SERVICE,
|
||||||
|
)
|
||||||
|
entity_registry.async_get_or_create(
|
||||||
|
"conversation",
|
||||||
|
DOMAIN,
|
||||||
|
mock_config_entry_2.entry_id,
|
||||||
|
config_entry=mock_config_entry_2,
|
||||||
|
device_id=device_2.id,
|
||||||
|
suggested_object_id="ollama_2",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run migration
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.ollama.async_setup_entry",
|
||||||
|
return_value=True,
|
||||||
|
):
|
||||||
|
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Should have only one entry left (consolidated)
|
||||||
|
entries = hass.config_entries.async_entries(DOMAIN)
|
||||||
|
assert len(entries) == 1
|
||||||
|
|
||||||
|
entry = entries[0]
|
||||||
|
assert entry.version == 2
|
||||||
|
assert not entry.options
|
||||||
|
assert len(entry.subentries) == 2 # Two subentries from the two original entries
|
||||||
|
|
||||||
|
# Check both subentries exist with correct data
|
||||||
|
subentries = list(entry.subentries.values())
|
||||||
|
titles = [sub.title for sub in subentries]
|
||||||
|
assert "Ollama" in titles
|
||||||
|
assert "Ollama 2" in titles
|
||||||
|
|
||||||
|
for subentry in subentries:
|
||||||
|
assert subentry.subentry_type == "conversation"
|
||||||
|
assert subentry.data == TEST_OPTIONS
|
||||||
|
|
||||||
|
# Check devices were migrated correctly
|
||||||
|
dev = device_registry.async_get_device(
|
||||||
|
identifiers={(DOMAIN, subentry.subentry_id)}
|
||||||
|
)
|
||||||
|
assert dev is not None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user