diff --git a/homeassistant/components/ollama/__init__.py b/homeassistant/components/ollama/__init__.py index 6fe4720d13f..e16550c1e94 100644 --- a/homeassistant/components/ollama/__init__.py +++ b/homeassistant/components/ollama/__init__.py @@ -28,6 +28,7 @@ from .const import ( CONF_NUM_CTX, CONF_PROMPT, CONF_THINK, + DEFAULT_AI_TASK_NAME, DEFAULT_NAME, DEFAULT_TIMEOUT, DOMAIN, @@ -47,7 +48,7 @@ __all__ = [ ] CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) -PLATFORMS = (Platform.CONVERSATION,) +PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION) type OllamaConfigEntry = ConfigEntry[ollama.AsyncClient] @@ -118,6 +119,7 @@ async def async_migrate_integration(hass: HomeAssistant) -> None: parent_entry = api_keys_entries[entry.data[CONF_URL]] hass.config_entries.async_add_subentry(parent_entry, subentry) + conversation_entity = entity_registry.async_get_entity_id( "conversation", DOMAIN, @@ -208,6 +210,31 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OllamaConfigEntry) -> minor_version=1, ) + if entry.version == 3 and entry.minor_version == 1: + # Add AI Task subentry with default options. We can only create a new + # subentry if we can find an existing model in the entry. The model + # was removed in the previous migration step, so we need to + # check the subentries for an existing model. + existing_model = next( + iter( + model + for subentry in entry.subentries.values() + if (model := subentry.data.get(CONF_MODEL)) is not None + ), + None, + ) + if existing_model: + hass.config_entries.async_add_subentry( + entry, + ConfigSubentry( + data=MappingProxyType({CONF_MODEL: existing_model}), + subentry_type="ai_task_data", + title=DEFAULT_AI_TASK_NAME, + unique_id=None, + ), + ) + hass.config_entries.async_update_entry(entry, minor_version=2) + _LOGGER.debug( "Migration to version %s:%s successful", entry.version, entry.minor_version ) diff --git a/homeassistant/components/ollama/ai_task.py b/homeassistant/components/ollama/ai_task.py new file mode 100644 index 00000000000..d796b28aac8 --- /dev/null +++ b/homeassistant/components/ollama/ai_task.py @@ -0,0 +1,77 @@ +"""AI Task integration for Ollama.""" + +from __future__ import annotations + +from json import JSONDecodeError +import logging + +from homeassistant.components import ai_task, conversation +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback +from homeassistant.util.json import json_loads + +from .entity import OllamaBaseLLMEntity + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddConfigEntryEntitiesCallback, +) -> None: + """Set up AI Task entities.""" + for subentry in config_entry.subentries.values(): + if subentry.subentry_type != "ai_task_data": + continue + + async_add_entities( + [OllamaTaskEntity(config_entry, subentry)], + config_subentry_id=subentry.subentry_id, + ) + + +class OllamaTaskEntity( + ai_task.AITaskEntity, + OllamaBaseLLMEntity, +): + """Ollama AI Task entity.""" + + _attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA + + async def _async_generate_data( + self, + task: ai_task.GenDataTask, + chat_log: conversation.ChatLog, + ) -> ai_task.GenDataTaskResult: + """Handle a generate data task.""" + await self._async_handle_chat_log(chat_log, task.structure) + + if not isinstance(chat_log.content[-1], conversation.AssistantContent): + raise HomeAssistantError( + "Last content in chat log is not an AssistantContent" + ) + + text = chat_log.content[-1].content or "" + + if not task.structure: + return ai_task.GenDataTaskResult( + conversation_id=chat_log.conversation_id, + data=text, + ) + try: + data = json_loads(text) + except JSONDecodeError as err: + _LOGGER.error( + "Failed to parse JSON response: %s. Response: %s", + err, + text, + ) + raise HomeAssistantError("Error with Ollama structured response") from err + + return ai_task.GenDataTaskResult( + conversation_id=chat_log.conversation_id, + data=data, + ) diff --git a/homeassistant/components/ollama/config_flow.py b/homeassistant/components/ollama/config_flow.py index 49eb12a5c23..cca917f6c29 100644 --- a/homeassistant/components/ollama/config_flow.py +++ b/homeassistant/components/ollama/config_flow.py @@ -46,6 +46,8 @@ from .const import ( CONF_NUM_CTX, CONF_PROMPT, CONF_THINK, + DEFAULT_AI_TASK_NAME, + DEFAULT_CONVERSATION_NAME, DEFAULT_KEEP_ALIVE, DEFAULT_MAX_HISTORY, DEFAULT_MODEL, @@ -74,7 +76,7 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): """Handle a config flow for Ollama.""" VERSION = 3 - MINOR_VERSION = 1 + MINOR_VERSION = 2 def __init__(self) -> None: """Initialize config flow.""" @@ -136,11 +138,14 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): cls, config_entry: ConfigEntry ) -> dict[str, type[ConfigSubentryFlow]]: """Return subentries supported by this integration.""" - return {"conversation": ConversationSubentryFlowHandler} + return { + "conversation": OllamaSubentryFlowHandler, + "ai_task_data": OllamaSubentryFlowHandler, + } -class ConversationSubentryFlowHandler(ConfigSubentryFlow): - """Flow for managing conversation subentries.""" +class OllamaSubentryFlowHandler(ConfigSubentryFlow): + """Flow for managing Ollama subentries.""" def __init__(self) -> None: """Initialize the subentry flow.""" @@ -201,7 +206,11 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow): step_id="set_options", data_schema=vol.Schema( ollama_config_option_schema( - self.hass, self._is_new, options, models_to_list + self.hass, + self._is_new, + self._subentry_type, + options, + models_to_list, ) ), ) @@ -300,13 +309,19 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow): def ollama_config_option_schema( hass: HomeAssistant, is_new: bool, + subentry_type: str, options: Mapping[str, Any], models_to_list: list[SelectOptionDict], ) -> dict: """Ollama options schema.""" if is_new: + if subentry_type == "ai_task_data": + default_name = DEFAULT_AI_TASK_NAME + else: + default_name = DEFAULT_CONVERSATION_NAME + schema: dict = { - vol.Required(CONF_NAME, default="Ollama Conversation"): str, + vol.Required(CONF_NAME, default=default_name): str, } else: schema = {} @@ -319,29 +334,38 @@ def ollama_config_option_schema( ): SelectSelector( SelectSelectorConfig(options=models_to_list, custom_value=True) ), - vol.Optional( - CONF_PROMPT, - description={ - "suggested_value": options.get( - CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT - ) - }, - ): TemplateSelector(), - vol.Optional( - CONF_LLM_HASS_API, - description={"suggested_value": options.get(CONF_LLM_HASS_API)}, - ): SelectSelector( - SelectSelectorConfig( - options=[ - SelectOptionDict( - label=api.name, - value=api.id, + } + ) + if subentry_type == "conversation": + schema.update( + { + vol.Optional( + CONF_PROMPT, + description={ + "suggested_value": options.get( + CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT ) - for api in llm.async_get_apis(hass) - ], - multiple=True, - ) - ), + }, + ): TemplateSelector(), + vol.Optional( + CONF_LLM_HASS_API, + description={"suggested_value": options.get(CONF_LLM_HASS_API)}, + ): SelectSelector( + SelectSelectorConfig( + options=[ + SelectOptionDict( + label=api.name, + value=api.id, + ) + for api in llm.async_get_apis(hass) + ], + multiple=True, + ) + ), + } + ) + schema.update( + { vol.Optional( CONF_NUM_CTX, description={ diff --git a/homeassistant/components/ollama/const.py b/homeassistant/components/ollama/const.py index 3175525c70d..7e80570bd5e 100644 --- a/homeassistant/components/ollama/const.py +++ b/homeassistant/components/ollama/const.py @@ -159,3 +159,10 @@ MODEL_NAMES = [ # https://ollama.com/library "zephyr", ] DEFAULT_MODEL = "llama3.2:latest" + +DEFAULT_CONVERSATION_NAME = "Ollama Conversation" +DEFAULT_AI_TASK_NAME = "Ollama AI Task" + +RECOMMENDED_CONVERSATION_OPTIONS = { + CONF_MAX_HISTORY: DEFAULT_MAX_HISTORY, +} diff --git a/homeassistant/components/ollama/entity.py b/homeassistant/components/ollama/entity.py index 7b63b1dff00..4122d0c67d8 100644 --- a/homeassistant/components/ollama/entity.py +++ b/homeassistant/components/ollama/entity.py @@ -8,6 +8,7 @@ import logging from typing import Any import ollama +import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import conversation @@ -180,6 +181,7 @@ class OllamaBaseLLMEntity(Entity): async def _async_handle_chat_log( self, chat_log: conversation.ChatLog, + structure: vol.Schema | None = None, ) -> None: """Generate an answer for the chat log.""" settings = {**self.entry.data, **self.subentry.data} @@ -200,6 +202,17 @@ class OllamaBaseLLMEntity(Entity): max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY)) self._trim_history(message_history, max_messages) + output_format: dict[str, Any] | None = None + if structure: + output_format = convert( + structure, + custom_serializer=( + chat_log.llm_api.custom_serializer + if chat_log.llm_api + else llm.selector_serializer + ), + ) + # Get response # To prevent infinite loops, we limit the number of iterations for _iteration in range(MAX_TOOL_ITERATIONS): @@ -214,6 +227,7 @@ class OllamaBaseLLMEntity(Entity): keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s", options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)}, think=settings.get(CONF_THINK), + format=output_format, ) except (ollama.RequestError, ollama.ResponseError) as err: _LOGGER.error("Unexpected error talking to Ollama server: %s", err) diff --git a/homeassistant/components/ollama/strings.json b/homeassistant/components/ollama/strings.json index bb08e4a4684..4261b2286bf 100644 --- a/homeassistant/components/ollama/strings.json +++ b/homeassistant/components/ollama/strings.json @@ -55,6 +55,44 @@ "progress": { "download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details." } + }, + "ai_task_data": { + "initiate_flow": { + "user": "Add Generate data with AI service", + "reconfigure": "Reconfigure Generate data with AI service" + }, + "entry_type": "Generate data with AI service", + "step": { + "set_options": { + "data": { + "model": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::model%]", + "name": "[%key:common::config_flow::data::name%]", + "prompt": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::prompt%]", + "max_history": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::max_history%]", + "num_ctx": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::num_ctx%]", + "keep_alive": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::keep_alive%]", + "think": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::think%]" + }, + "data_description": { + "prompt": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::prompt%]", + "keep_alive": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::keep_alive%]", + "num_ctx": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::num_ctx%]", + "think": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::think%]" + } + }, + "download": { + "title": "[%key:component::ollama::config_subentries::conversation::step::download::title%]" + } + }, + "abort": { + "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]", + "entry_not_loaded": "[%key:component::ollama::config_subentries::conversation::abort::entry_not_loaded%]", + "download_failed": "[%key:component::ollama::config_subentries::conversation::abort::download_failed%]", + "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]" + }, + "progress": { + "download": "[%key:component::ollama::config_subentries::conversation::progress::download%]" + } } } } diff --git a/tests/components/ollama/__init__.py b/tests/components/ollama/__init__.py index 92db3b13304..9e7ae4772d4 100644 --- a/tests/components/ollama/__init__.py +++ b/tests/components/ollama/__init__.py @@ -12,3 +12,8 @@ TEST_OPTIONS = { ollama.CONF_MAX_HISTORY: 2, ollama.CONF_MODEL: "test_model:latest", } + +TEST_AI_TASK_OPTIONS = { + ollama.CONF_MAX_HISTORY: 2, + ollama.CONF_MODEL: "test_model:latest", +} diff --git a/tests/components/ollama/conftest.py b/tests/components/ollama/conftest.py index 552e7dee20a..f3406bf5566 100644 --- a/tests/components/ollama/conftest.py +++ b/tests/components/ollama/conftest.py @@ -11,7 +11,7 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers import llm from homeassistant.setup import async_setup_component -from . import TEST_OPTIONS, TEST_USER_DATA +from . import TEST_AI_TASK_OPTIONS, TEST_OPTIONS, TEST_USER_DATA from tests.common import MockConfigEntry @@ -31,14 +31,20 @@ def mock_config_entry( domain=ollama.DOMAIN, data=TEST_USER_DATA, version=3, - minor_version=1, + minor_version=2, subentries_data=[ { "data": {**TEST_OPTIONS, **mock_config_entry_options}, "subentry_type": "conversation", "title": "Ollama Conversation", "unique_id": None, - } + }, + { + "data": TEST_AI_TASK_OPTIONS, + "subentry_type": "ai_task_data", + "title": "Ollama AI Task", + "unique_id": None, + }, ], ) entry.add_to_hass(hass) diff --git a/tests/components/ollama/test_ai_task.py b/tests/components/ollama/test_ai_task.py new file mode 100644 index 00000000000..ee812e7b316 --- /dev/null +++ b/tests/components/ollama/test_ai_task.py @@ -0,0 +1,245 @@ +"""Test AI Task platform of Ollama integration.""" + +from unittest.mock import patch + +import pytest +import voluptuous as vol + +from homeassistant.components import ai_task +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import entity_registry as er, selector + +from tests.common import MockConfigEntry + + +@pytest.mark.usefixtures("mock_init_component") +async def test_generate_data( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Test AI Task data generation.""" + entity_id = "ai_task.ollama_ai_task" + + # Ensure entity is linked to the subentry + entity_entry = entity_registry.async_get(entity_id) + ai_task_entry = next( + iter( + entry + for entry in mock_config_entry.subentries.values() + if entry.subentry_type == "ai_task_data" + ) + ) + assert entity_entry is not None + assert entity_entry.config_entry_id == mock_config_entry.entry_id + assert entity_entry.config_subentry_id == ai_task_entry.subentry_id + + # Mock the Ollama chat response as an async iterator + async def mock_chat_response(): + """Mock streaming response.""" + yield { + "message": {"role": "assistant", "content": "Generated test data"}, + "done": True, + "done_reason": "stop", + } + + with patch( + "ollama.AsyncClient.chat", + return_value=mock_chat_response(), + ): + result = await ai_task.async_generate_data( + hass, + task_name="Test Task", + entity_id=entity_id, + instructions="Generate test data", + ) + + assert result.data == "Generated test data" + + +@pytest.mark.usefixtures("mock_init_component") +async def test_run_task_with_streaming( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Test AI Task data generation with streaming response.""" + entity_id = "ai_task.ollama_ai_task" + + async def mock_stream(): + """Mock streaming response.""" + yield {"message": {"role": "assistant", "content": "Stream "}} + yield { + "message": {"role": "assistant", "content": "response"}, + "done": True, + "done_reason": "stop", + } + + with patch( + "ollama.AsyncClient.chat", + return_value=mock_stream(), + ): + result = await ai_task.async_generate_data( + hass, + task_name="Test Streaming Task", + entity_id=entity_id, + instructions="Generate streaming data", + ) + + assert result.data == "Stream response" + + +@pytest.mark.usefixtures("mock_init_component") +async def test_run_task_connection_error( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Test AI Task with connection error.""" + entity_id = "ai_task.ollama_ai_task" + + with ( + patch( + "ollama.AsyncClient.chat", + side_effect=Exception("Connection failed"), + ), + pytest.raises(Exception, match="Connection failed"), + ): + await ai_task.async_generate_data( + hass, + task_name="Test Error Task", + entity_id=entity_id, + instructions="Generate data that will fail", + ) + + +@pytest.mark.usefixtures("mock_init_component") +async def test_run_task_empty_response( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Test AI Task with empty response.""" + entity_id = "ai_task.ollama_ai_task" + + # Mock response with space (minimally non-empty) + async def mock_minimal_response(): + """Mock minimal streaming response.""" + yield { + "message": {"role": "assistant", "content": " "}, + "done": True, + "done_reason": "stop", + } + + with patch( + "ollama.AsyncClient.chat", + return_value=mock_minimal_response(), + ): + result = await ai_task.async_generate_data( + hass, + task_name="Test Minimal Task", + entity_id=entity_id, + instructions="Generate minimal data", + ) + + assert result.data == " " + + +@pytest.mark.usefixtures("mock_init_component") +async def test_generate_structured_data( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Test AI Task data generation.""" + entity_id = "ai_task.ollama_ai_task" + + # Mock the Ollama chat response as an async iterator + async def mock_chat_response(): + """Mock streaming response.""" + yield { + "message": { + "role": "assistant", + "content": '{"characters": ["Mario", "Luigi"]}', + }, + "done": True, + "done_reason": "stop", + } + + with patch( + "ollama.AsyncClient.chat", + return_value=mock_chat_response(), + ) as mock_chat: + result = await ai_task.async_generate_data( + hass, + task_name="Test Task", + entity_id=entity_id, + instructions="Generate test data", + structure=vol.Schema( + { + vol.Required("characters"): selector.selector( + { + "text": { + "multiple": True, + } + } + ) + }, + ), + ) + assert result.data == {"characters": ["Mario", "Luigi"]} + + assert mock_chat.call_count == 1 + assert mock_chat.call_args[1]["format"] == { + "type": "object", + "properties": {"characters": {"items": {"type": "string"}, "type": "array"}}, + "required": ["characters"], + } + + +@pytest.mark.usefixtures("mock_init_component") +async def test_generate_invalid_structured_data( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Test AI Task data generation.""" + entity_id = "ai_task.ollama_ai_task" + + # Mock the Ollama chat response as an async iterator + async def mock_chat_response(): + """Mock streaming response.""" + yield { + "message": { + "role": "assistant", + "content": "INVALID JSON RESPONSE", + }, + "done": True, + "done_reason": "stop", + } + + with ( + patch( + "ollama.AsyncClient.chat", + return_value=mock_chat_response(), + ), + pytest.raises(HomeAssistantError), + ): + await ai_task.async_generate_data( + hass, + task_name="Test Task", + entity_id=entity_id, + instructions="Generate test data", + structure=vol.Schema( + { + vol.Required("characters"): selector.selector( + { + "text": { + "multiple": True, + } + } + ) + }, + ), + ) diff --git a/tests/components/ollama/test_config_flow.py b/tests/components/ollama/test_config_flow.py index 7372a460c95..1a873c2adb7 100644 --- a/tests/components/ollama/test_config_flow.py +++ b/tests/components/ollama/test_config_flow.py @@ -461,3 +461,78 @@ async def test_subentry_reconfigure_with_download( ollama.CONF_NUM_CTX: 8192.0, ollama.CONF_THINK: True, } + + +async def test_creating_ai_task_subentry( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, +) -> None: + """Test creating an AI task subentry.""" + old_subentries = set(mock_config_entry.subentries) + # Original conversation + original ai_task + assert len(mock_config_entry.subentries) == 2 + + with patch( + "ollama.AsyncClient.list", + return_value={"models": [{"model": "test_model:latest"}]}, + ): + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, "ai_task_data"), + context={"source": config_entries.SOURCE_USER}, + ) + + assert result.get("type") is FlowResultType.FORM + assert result.get("step_id") == "set_options" + assert not result.get("errors") + + with patch( + "ollama.AsyncClient.list", + return_value={"models": [{"model": "test_model:latest"}]}, + ): + result2 = await hass.config_entries.subentries.async_configure( + result["flow_id"], + { + "name": "Custom AI Task", + ollama.CONF_MODEL: "test_model:latest", + ollama.CONF_MAX_HISTORY: 5, + ollama.CONF_NUM_CTX: 4096, + ollama.CONF_KEEP_ALIVE: 30, + ollama.CONF_THINK: False, + }, + ) + await hass.async_block_till_done() + + assert result2.get("type") is FlowResultType.CREATE_ENTRY + assert result2.get("title") == "Custom AI Task" + assert result2.get("data") == { + ollama.CONF_MODEL: "test_model:latest", + ollama.CONF_MAX_HISTORY: 5, + ollama.CONF_NUM_CTX: 4096, + ollama.CONF_KEEP_ALIVE: 30, + ollama.CONF_THINK: False, + } + + assert ( + len(mock_config_entry.subentries) == 3 + ) # Original conversation + original ai_task + new ai_task + + new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0] + new_subentry = mock_config_entry.subentries[new_subentry_id] + assert new_subentry.subentry_type == "ai_task_data" + assert new_subentry.title == "Custom AI Task" + + +async def test_ai_task_subentry_not_loaded( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, +) -> None: + """Test creating an AI task subentry when entry is not loaded.""" + # Don't call mock_init_component to simulate not loaded state + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, "ai_task_data"), + context={"source": config_entries.SOURCE_USER}, + ) + + assert result.get("type") is FlowResultType.ABORT + assert result.get("reason") == "entry_not_loaded" diff --git a/tests/components/ollama/test_init.py b/tests/components/ollama/test_init.py index c7cd78fca9a..1db57302704 100644 --- a/tests/components/ollama/test_init.py +++ b/tests/components/ollama/test_init.py @@ -93,16 +93,23 @@ async def test_migration_from_v1( return_value=True, ): await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() assert mock_config_entry.version == 3 - assert mock_config_entry.minor_version == 1 + assert mock_config_entry.minor_version == 2 # After migration, parent entry should only have URL assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"} assert mock_config_entry.options == {} - assert len(mock_config_entry.subentries) == 1 + assert len(mock_config_entry.subentries) == 2 - subentry = next(iter(mock_config_entry.subentries.values())) + subentry = next( + iter( + entry + for entry in mock_config_entry.subentries.values() + if entry.subentry_type == "conversation" + ) + ) assert subentry.unique_id is None assert subentry.title == "llama-3.2-8b" assert subentry.subentry_type == "conversation" @@ -110,6 +117,18 @@ async def test_migration_from_v1( expected_subentry_data = TEST_OPTIONS.copy() assert subentry.data == expected_subentry_data + # Find the AI Task subentry + ai_task_subentry = next( + iter( + entry + for entry in mock_config_entry.subentries.values() + if entry.subentry_type == "ai_task_data" + ) + ) + assert ai_task_subentry.unique_id is None + assert ai_task_subentry.title == "Ollama AI Task" + assert ai_task_subentry.subentry_type == "ai_task_data" + migrated_entity = entity_registry.async_get(entity.entity_id) assert migrated_entity is not None assert migrated_entity.config_entry_id == mock_config_entry.entry_id @@ -204,10 +223,17 @@ async def test_migration_from_v1_with_multiple_urls( for idx, entry in enumerate(entries): assert entry.version == 3 - assert entry.minor_version == 1 + assert entry.minor_version == 2 assert not entry.options - assert len(entry.subentries) == 1 - subentry = list(entry.subentries.values())[0] + assert len(entry.subentries) == 2 + + subentry = next( + iter( + subentry + for subentry in entry.subentries.values() + if subentry.subentry_type == "conversation" + ) + ) assert subentry.subentry_type == "conversation" # Subentry should include the model along with the original options expected_subentry_data = TEST_OPTIONS.copy() @@ -215,6 +241,17 @@ async def test_migration_from_v1_with_multiple_urls( assert subentry.data == expected_subentry_data assert subentry.title == f"Ollama {idx + 1}" + # Find the AI Task subentry + ai_task_subentry = next( + iter( + subentry + for subentry in entry.subentries.values() + if subentry.subentry_type == "ai_task_data" + ) + ) + assert ai_task_subentry.subentry_type == "ai_task_data" + assert ai_task_subentry.title == "Ollama AI Task" + dev = device_registry.async_get_device( identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)} ) @@ -295,9 +332,10 @@ async def test_migration_from_v1_with_same_urls( entry = entries[0] assert entry.version == 3 - assert entry.minor_version == 1 + assert entry.minor_version == 2 assert not entry.options - assert len(entry.subentries) == 2 # Two subentries from the two original entries + # Two conversation subentries from the two original entries and 1 aitask subentry + assert len(entry.subentries) == 3 # Check both subentries exist with correct data subentries = list(entry.subentries.values()) @@ -305,7 +343,11 @@ async def test_migration_from_v1_with_same_urls( assert "Ollama" in titles assert "Ollama 2" in titles - for subentry in subentries: + conversation_subentries = [ + subentry for subentry in subentries if subentry.subentry_type == "conversation" + ] + assert len(conversation_subentries) == 2 + for subentry in conversation_subentries: assert subentry.subentry_type == "conversation" # Subentry should include the model along with the original options expected_subentry_data = TEST_OPTIONS.copy() @@ -415,10 +457,10 @@ async def test_migration_from_v2_1( assert len(entries) == 1 entry = entries[0] assert entry.version == 3 - assert entry.minor_version == 1 + assert entry.minor_version == 2 assert not entry.options assert entry.title == "Ollama" - assert len(entry.subentries) == 2 + assert len(entry.subentries) == 3 conversation_subentries = [ subentry for subentry in entry.subentries.values() @@ -504,14 +546,44 @@ async def test_migration_from_v2_2(hass: HomeAssistant) -> None: # Check migration to v3.1 assert mock_config_entry.version == 3 - assert mock_config_entry.minor_version == 1 + assert mock_config_entry.minor_version == 2 # Check that model was moved from main data to subentry assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"} - assert len(mock_config_entry.subentries) == 1 + assert len(mock_config_entry.subentries) == 2 subentry = next(iter(mock_config_entry.subentries.values())) assert subentry.data == { **V21_TEST_USER_DATA, ollama.CONF_MODEL: "test_model:latest", } + + +async def test_migration_from_v3_1_without_subentry(hass: HomeAssistant) -> None: + """Test migration from version 3.1 where there is no existing subentry. + + This exercises the code path where the model is not moved to a subentry + because the subentry does not exist, which is a scenario that can happen + if the user created the config entry without adding a subentry, or + if the user manually removed the subentry after the migration to v3.1. + """ + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data={ + ollama.CONF_MODEL: "test_model:latest", + }, + version=3, + minor_version=1, + ) + mock_config_entry.add_to_hass(hass) + + with patch( + "homeassistant.components.ollama.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + + assert mock_config_entry.version == 3 + assert mock_config_entry.minor_version == 2 + + assert next(iter(mock_config_entry.subentries.values()), None) is None