mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Add AI Task support in Ollama (#148226)
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
8cb9cadce9
commit
4b5c04b2f0
@ -28,6 +28,7 @@ from .const import (
|
||||
CONF_NUM_CTX,
|
||||
CONF_PROMPT,
|
||||
CONF_THINK,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_NAME,
|
||||
DEFAULT_TIMEOUT,
|
||||
DOMAIN,
|
||||
@ -47,7 +48,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
PLATFORMS = (Platform.CONVERSATION,)
|
||||
PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION)
|
||||
|
||||
type OllamaConfigEntry = ConfigEntry[ollama.AsyncClient]
|
||||
|
||||
@ -118,6 +119,7 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
|
||||
parent_entry = api_keys_entries[entry.data[CONF_URL]]
|
||||
|
||||
hass.config_entries.async_add_subentry(parent_entry, subentry)
|
||||
|
||||
conversation_entity = entity_registry.async_get_entity_id(
|
||||
"conversation",
|
||||
DOMAIN,
|
||||
@ -208,6 +210,31 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OllamaConfigEntry) ->
|
||||
minor_version=1,
|
||||
)
|
||||
|
||||
if entry.version == 3 and entry.minor_version == 1:
|
||||
# Add AI Task subentry with default options. We can only create a new
|
||||
# subentry if we can find an existing model in the entry. The model
|
||||
# was removed in the previous migration step, so we need to
|
||||
# check the subentries for an existing model.
|
||||
existing_model = next(
|
||||
iter(
|
||||
model
|
||||
for subentry in entry.subentries.values()
|
||||
if (model := subentry.data.get(CONF_MODEL)) is not None
|
||||
),
|
||||
None,
|
||||
)
|
||||
if existing_model:
|
||||
hass.config_entries.async_add_subentry(
|
||||
entry,
|
||||
ConfigSubentry(
|
||||
data=MappingProxyType({CONF_MODEL: existing_model}),
|
||||
subentry_type="ai_task_data",
|
||||
title=DEFAULT_AI_TASK_NAME,
|
||||
unique_id=None,
|
||||
),
|
||||
)
|
||||
hass.config_entries.async_update_entry(entry, minor_version=2)
|
||||
|
||||
_LOGGER.debug(
|
||||
"Migration to version %s:%s successful", entry.version, entry.minor_version
|
||||
)
|
||||
|
77
homeassistant/components/ollama/ai_task.py
Normal file
77
homeassistant/components/ollama/ai_task.py
Normal file
@ -0,0 +1,77 @@
|
||||
"""AI Task integration for Ollama."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from json import JSONDecodeError
|
||||
import logging
|
||||
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
from homeassistant.util.json import json_loads
|
||||
|
||||
from .entity import OllamaBaseLLMEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up AI Task entities."""
|
||||
for subentry in config_entry.subentries.values():
|
||||
if subentry.subentry_type != "ai_task_data":
|
||||
continue
|
||||
|
||||
async_add_entities(
|
||||
[OllamaTaskEntity(config_entry, subentry)],
|
||||
config_subentry_id=subentry.subentry_id,
|
||||
)
|
||||
|
||||
|
||||
class OllamaTaskEntity(
|
||||
ai_task.AITaskEntity,
|
||||
OllamaBaseLLMEntity,
|
||||
):
|
||||
"""Ollama AI Task entity."""
|
||||
|
||||
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
|
||||
async def _async_generate_data(
|
||||
self,
|
||||
task: ai_task.GenDataTask,
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> ai_task.GenDataTaskResult:
|
||||
"""Handle a generate data task."""
|
||||
await self._async_handle_chat_log(chat_log, task.structure)
|
||||
|
||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||
raise HomeAssistantError(
|
||||
"Last content in chat log is not an AssistantContent"
|
||||
)
|
||||
|
||||
text = chat_log.content[-1].content or ""
|
||||
|
||||
if not task.structure:
|
||||
return ai_task.GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data=text,
|
||||
)
|
||||
try:
|
||||
data = json_loads(text)
|
||||
except JSONDecodeError as err:
|
||||
_LOGGER.error(
|
||||
"Failed to parse JSON response: %s. Response: %s",
|
||||
err,
|
||||
text,
|
||||
)
|
||||
raise HomeAssistantError("Error with Ollama structured response") from err
|
||||
|
||||
return ai_task.GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data=data,
|
||||
)
|
@ -46,6 +46,8 @@ from .const import (
|
||||
CONF_NUM_CTX,
|
||||
CONF_PROMPT,
|
||||
CONF_THINK,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEFAULT_KEEP_ALIVE,
|
||||
DEFAULT_MAX_HISTORY,
|
||||
DEFAULT_MODEL,
|
||||
@ -74,7 +76,7 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for Ollama."""
|
||||
|
||||
VERSION = 3
|
||||
MINOR_VERSION = 1
|
||||
MINOR_VERSION = 2
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize config flow."""
|
||||
@ -136,11 +138,14 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
cls, config_entry: ConfigEntry
|
||||
) -> dict[str, type[ConfigSubentryFlow]]:
|
||||
"""Return subentries supported by this integration."""
|
||||
return {"conversation": ConversationSubentryFlowHandler}
|
||||
return {
|
||||
"conversation": OllamaSubentryFlowHandler,
|
||||
"ai_task_data": OllamaSubentryFlowHandler,
|
||||
}
|
||||
|
||||
|
||||
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
"""Flow for managing conversation subentries."""
|
||||
class OllamaSubentryFlowHandler(ConfigSubentryFlow):
|
||||
"""Flow for managing Ollama subentries."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the subentry flow."""
|
||||
@ -201,7 +206,11 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
step_id="set_options",
|
||||
data_schema=vol.Schema(
|
||||
ollama_config_option_schema(
|
||||
self.hass, self._is_new, options, models_to_list
|
||||
self.hass,
|
||||
self._is_new,
|
||||
self._subentry_type,
|
||||
options,
|
||||
models_to_list,
|
||||
)
|
||||
),
|
||||
)
|
||||
@ -300,13 +309,19 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
def ollama_config_option_schema(
|
||||
hass: HomeAssistant,
|
||||
is_new: bool,
|
||||
subentry_type: str,
|
||||
options: Mapping[str, Any],
|
||||
models_to_list: list[SelectOptionDict],
|
||||
) -> dict:
|
||||
"""Ollama options schema."""
|
||||
if is_new:
|
||||
if subentry_type == "ai_task_data":
|
||||
default_name = DEFAULT_AI_TASK_NAME
|
||||
else:
|
||||
default_name = DEFAULT_CONVERSATION_NAME
|
||||
|
||||
schema: dict = {
|
||||
vol.Required(CONF_NAME, default="Ollama Conversation"): str,
|
||||
vol.Required(CONF_NAME, default=default_name): str,
|
||||
}
|
||||
else:
|
||||
schema = {}
|
||||
@ -319,29 +334,38 @@ def ollama_config_option_schema(
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(options=models_to_list, custom_value=True)
|
||||
),
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={
|
||||
"suggested_value": options.get(
|
||||
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
|
||||
)
|
||||
},
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(
|
||||
options=[
|
||||
SelectOptionDict(
|
||||
label=api.name,
|
||||
value=api.id,
|
||||
}
|
||||
)
|
||||
if subentry_type == "conversation":
|
||||
schema.update(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={
|
||||
"suggested_value": options.get(
|
||||
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
|
||||
)
|
||||
for api in llm.async_get_apis(hass)
|
||||
],
|
||||
multiple=True,
|
||||
)
|
||||
),
|
||||
},
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(
|
||||
options=[
|
||||
SelectOptionDict(
|
||||
label=api.name,
|
||||
value=api.id,
|
||||
)
|
||||
for api in llm.async_get_apis(hass)
|
||||
],
|
||||
multiple=True,
|
||||
)
|
||||
),
|
||||
}
|
||||
)
|
||||
schema.update(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_NUM_CTX,
|
||||
description={
|
||||
|
@ -159,3 +159,10 @@ MODEL_NAMES = [ # https://ollama.com/library
|
||||
"zephyr",
|
||||
]
|
||||
DEFAULT_MODEL = "llama3.2:latest"
|
||||
|
||||
DEFAULT_CONVERSATION_NAME = "Ollama Conversation"
|
||||
DEFAULT_AI_TASK_NAME = "Ollama AI Task"
|
||||
|
||||
RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||
CONF_MAX_HISTORY: DEFAULT_MAX_HISTORY,
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
import ollama
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import conversation
|
||||
@ -180,6 +181,7 @@ class OllamaBaseLLMEntity(Entity):
|
||||
async def _async_handle_chat_log(
|
||||
self,
|
||||
chat_log: conversation.ChatLog,
|
||||
structure: vol.Schema | None = None,
|
||||
) -> None:
|
||||
"""Generate an answer for the chat log."""
|
||||
settings = {**self.entry.data, **self.subentry.data}
|
||||
@ -200,6 +202,17 @@ class OllamaBaseLLMEntity(Entity):
|
||||
max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
|
||||
self._trim_history(message_history, max_messages)
|
||||
|
||||
output_format: dict[str, Any] | None = None
|
||||
if structure:
|
||||
output_format = convert(
|
||||
structure,
|
||||
custom_serializer=(
|
||||
chat_log.llm_api.custom_serializer
|
||||
if chat_log.llm_api
|
||||
else llm.selector_serializer
|
||||
),
|
||||
)
|
||||
|
||||
# Get response
|
||||
# To prevent infinite loops, we limit the number of iterations
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
@ -214,6 +227,7 @@ class OllamaBaseLLMEntity(Entity):
|
||||
keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
|
||||
options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
|
||||
think=settings.get(CONF_THINK),
|
||||
format=output_format,
|
||||
)
|
||||
except (ollama.RequestError, ollama.ResponseError) as err:
|
||||
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
|
||||
|
@ -55,6 +55,44 @@
|
||||
"progress": {
|
||||
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
|
||||
}
|
||||
},
|
||||
"ai_task_data": {
|
||||
"initiate_flow": {
|
||||
"user": "Add Generate data with AI service",
|
||||
"reconfigure": "Reconfigure Generate data with AI service"
|
||||
},
|
||||
"entry_type": "Generate data with AI service",
|
||||
"step": {
|
||||
"set_options": {
|
||||
"data": {
|
||||
"model": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::model%]",
|
||||
"name": "[%key:common::config_flow::data::name%]",
|
||||
"prompt": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::prompt%]",
|
||||
"max_history": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::max_history%]",
|
||||
"num_ctx": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::num_ctx%]",
|
||||
"keep_alive": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::keep_alive%]",
|
||||
"think": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::think%]"
|
||||
},
|
||||
"data_description": {
|
||||
"prompt": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::prompt%]",
|
||||
"keep_alive": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::keep_alive%]",
|
||||
"num_ctx": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::num_ctx%]",
|
||||
"think": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::think%]"
|
||||
}
|
||||
},
|
||||
"download": {
|
||||
"title": "[%key:component::ollama::config_subentries::conversation::step::download::title%]"
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
|
||||
"entry_not_loaded": "[%key:component::ollama::config_subentries::conversation::abort::entry_not_loaded%]",
|
||||
"download_failed": "[%key:component::ollama::config_subentries::conversation::abort::download_failed%]",
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
|
||||
},
|
||||
"progress": {
|
||||
"download": "[%key:component::ollama::config_subentries::conversation::progress::download%]"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -12,3 +12,8 @@ TEST_OPTIONS = {
|
||||
ollama.CONF_MAX_HISTORY: 2,
|
||||
ollama.CONF_MODEL: "test_model:latest",
|
||||
}
|
||||
|
||||
TEST_AI_TASK_OPTIONS = {
|
||||
ollama.CONF_MAX_HISTORY: 2,
|
||||
ollama.CONF_MODEL: "test_model:latest",
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import TEST_OPTIONS, TEST_USER_DATA
|
||||
from . import TEST_AI_TASK_OPTIONS, TEST_OPTIONS, TEST_USER_DATA
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
@ -31,14 +31,20 @@ def mock_config_entry(
|
||||
domain=ollama.DOMAIN,
|
||||
data=TEST_USER_DATA,
|
||||
version=3,
|
||||
minor_version=1,
|
||||
minor_version=2,
|
||||
subentries_data=[
|
||||
{
|
||||
"data": {**TEST_OPTIONS, **mock_config_entry_options},
|
||||
"subentry_type": "conversation",
|
||||
"title": "Ollama Conversation",
|
||||
"unique_id": None,
|
||||
}
|
||||
},
|
||||
{
|
||||
"data": TEST_AI_TASK_OPTIONS,
|
||||
"subentry_type": "ai_task_data",
|
||||
"title": "Ollama AI Task",
|
||||
"unique_id": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
245
tests/components/ollama/test_ai_task.py
Normal file
245
tests/components/ollama/test_ai_task.py
Normal file
@ -0,0 +1,245 @@
|
||||
"""Test AI Task platform of Ollama integration."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import ai_task
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import entity_registry as er, selector
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_generate_data(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test AI Task data generation."""
|
||||
entity_id = "ai_task.ollama_ai_task"
|
||||
|
||||
# Ensure entity is linked to the subentry
|
||||
entity_entry = entity_registry.async_get(entity_id)
|
||||
ai_task_entry = next(
|
||||
iter(
|
||||
entry
|
||||
for entry in mock_config_entry.subentries.values()
|
||||
if entry.subentry_type == "ai_task_data"
|
||||
)
|
||||
)
|
||||
assert entity_entry is not None
|
||||
assert entity_entry.config_entry_id == mock_config_entry.entry_id
|
||||
assert entity_entry.config_subentry_id == ai_task_entry.subentry_id
|
||||
|
||||
# Mock the Ollama chat response as an async iterator
|
||||
async def mock_chat_response():
|
||||
"""Mock streaming response."""
|
||||
yield {
|
||||
"message": {"role": "assistant", "content": "Generated test data"},
|
||||
"done": True,
|
||||
"done_reason": "stop",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
return_value=mock_chat_response(),
|
||||
):
|
||||
result = await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=entity_id,
|
||||
instructions="Generate test data",
|
||||
)
|
||||
|
||||
assert result.data == "Generated test data"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_run_task_with_streaming(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test AI Task data generation with streaming response."""
|
||||
entity_id = "ai_task.ollama_ai_task"
|
||||
|
||||
async def mock_stream():
|
||||
"""Mock streaming response."""
|
||||
yield {"message": {"role": "assistant", "content": "Stream "}}
|
||||
yield {
|
||||
"message": {"role": "assistant", "content": "response"},
|
||||
"done": True,
|
||||
"done_reason": "stop",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
return_value=mock_stream(),
|
||||
):
|
||||
result = await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Streaming Task",
|
||||
entity_id=entity_id,
|
||||
instructions="Generate streaming data",
|
||||
)
|
||||
|
||||
assert result.data == "Stream response"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_run_task_connection_error(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test AI Task with connection error."""
|
||||
entity_id = "ai_task.ollama_ai_task"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
side_effect=Exception("Connection failed"),
|
||||
),
|
||||
pytest.raises(Exception, match="Connection failed"),
|
||||
):
|
||||
await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Error Task",
|
||||
entity_id=entity_id,
|
||||
instructions="Generate data that will fail",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_run_task_empty_response(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test AI Task with empty response."""
|
||||
entity_id = "ai_task.ollama_ai_task"
|
||||
|
||||
# Mock response with space (minimally non-empty)
|
||||
async def mock_minimal_response():
|
||||
"""Mock minimal streaming response."""
|
||||
yield {
|
||||
"message": {"role": "assistant", "content": " "},
|
||||
"done": True,
|
||||
"done_reason": "stop",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
return_value=mock_minimal_response(),
|
||||
):
|
||||
result = await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Minimal Task",
|
||||
entity_id=entity_id,
|
||||
instructions="Generate minimal data",
|
||||
)
|
||||
|
||||
assert result.data == " "
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_generate_structured_data(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test AI Task data generation."""
|
||||
entity_id = "ai_task.ollama_ai_task"
|
||||
|
||||
# Mock the Ollama chat response as an async iterator
|
||||
async def mock_chat_response():
|
||||
"""Mock streaming response."""
|
||||
yield {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": '{"characters": ["Mario", "Luigi"]}',
|
||||
},
|
||||
"done": True,
|
||||
"done_reason": "stop",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
return_value=mock_chat_response(),
|
||||
) as mock_chat:
|
||||
result = await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=entity_id,
|
||||
instructions="Generate test data",
|
||||
structure=vol.Schema(
|
||||
{
|
||||
vol.Required("characters"): selector.selector(
|
||||
{
|
||||
"text": {
|
||||
"multiple": True,
|
||||
}
|
||||
}
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
assert result.data == {"characters": ["Mario", "Luigi"]}
|
||||
|
||||
assert mock_chat.call_count == 1
|
||||
assert mock_chat.call_args[1]["format"] == {
|
||||
"type": "object",
|
||||
"properties": {"characters": {"items": {"type": "string"}, "type": "array"}},
|
||||
"required": ["characters"],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_generate_invalid_structured_data(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test AI Task data generation."""
|
||||
entity_id = "ai_task.ollama_ai_task"
|
||||
|
||||
# Mock the Ollama chat response as an async iterator
|
||||
async def mock_chat_response():
|
||||
"""Mock streaming response."""
|
||||
yield {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "INVALID JSON RESPONSE",
|
||||
},
|
||||
"done": True,
|
||||
"done_reason": "stop",
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"ollama.AsyncClient.chat",
|
||||
return_value=mock_chat_response(),
|
||||
),
|
||||
pytest.raises(HomeAssistantError),
|
||||
):
|
||||
await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=entity_id,
|
||||
instructions="Generate test data",
|
||||
structure=vol.Schema(
|
||||
{
|
||||
vol.Required("characters"): selector.selector(
|
||||
{
|
||||
"text": {
|
||||
"multiple": True,
|
||||
}
|
||||
}
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
@ -461,3 +461,78 @@ async def test_subentry_reconfigure_with_download(
|
||||
ollama.CONF_NUM_CTX: 8192.0,
|
||||
ollama.CONF_THINK: True,
|
||||
}
|
||||
|
||||
|
||||
async def test_creating_ai_task_subentry(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test creating an AI task subentry."""
|
||||
old_subentries = set(mock_config_entry.subentries)
|
||||
# Original conversation + original ai_task
|
||||
assert len(mock_config_entry.subentries) == 2
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.list",
|
||||
return_value={"models": [{"model": "test_model:latest"}]},
|
||||
):
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "ai_task_data"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert result.get("type") is FlowResultType.FORM
|
||||
assert result.get("step_id") == "set_options"
|
||||
assert not result.get("errors")
|
||||
|
||||
with patch(
|
||||
"ollama.AsyncClient.list",
|
||||
return_value={"models": [{"model": "test_model:latest"}]},
|
||||
):
|
||||
result2 = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
"name": "Custom AI Task",
|
||||
ollama.CONF_MODEL: "test_model:latest",
|
||||
ollama.CONF_MAX_HISTORY: 5,
|
||||
ollama.CONF_NUM_CTX: 4096,
|
||||
ollama.CONF_KEEP_ALIVE: 30,
|
||||
ollama.CONF_THINK: False,
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2.get("type") is FlowResultType.CREATE_ENTRY
|
||||
assert result2.get("title") == "Custom AI Task"
|
||||
assert result2.get("data") == {
|
||||
ollama.CONF_MODEL: "test_model:latest",
|
||||
ollama.CONF_MAX_HISTORY: 5,
|
||||
ollama.CONF_NUM_CTX: 4096,
|
||||
ollama.CONF_KEEP_ALIVE: 30,
|
||||
ollama.CONF_THINK: False,
|
||||
}
|
||||
|
||||
assert (
|
||||
len(mock_config_entry.subentries) == 3
|
||||
) # Original conversation + original ai_task + new ai_task
|
||||
|
||||
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||
assert new_subentry.subentry_type == "ai_task_data"
|
||||
assert new_subentry.title == "Custom AI Task"
|
||||
|
||||
|
||||
async def test_ai_task_subentry_not_loaded(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test creating an AI task subentry when entry is not loaded."""
|
||||
# Don't call mock_init_component to simulate not loaded state
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "ai_task_data"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert result.get("type") is FlowResultType.ABORT
|
||||
assert result.get("reason") == "entry_not_loaded"
|
||||
|
@ -93,16 +93,23 @@ async def test_migration_from_v1(
|
||||
return_value=True,
|
||||
):
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_config_entry.version == 3
|
||||
assert mock_config_entry.minor_version == 1
|
||||
assert mock_config_entry.minor_version == 2
|
||||
# After migration, parent entry should only have URL
|
||||
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
|
||||
assert mock_config_entry.options == {}
|
||||
|
||||
assert len(mock_config_entry.subentries) == 1
|
||||
assert len(mock_config_entry.subentries) == 2
|
||||
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
subentry = next(
|
||||
iter(
|
||||
entry
|
||||
for entry in mock_config_entry.subentries.values()
|
||||
if entry.subentry_type == "conversation"
|
||||
)
|
||||
)
|
||||
assert subentry.unique_id is None
|
||||
assert subentry.title == "llama-3.2-8b"
|
||||
assert subentry.subentry_type == "conversation"
|
||||
@ -110,6 +117,18 @@ async def test_migration_from_v1(
|
||||
expected_subentry_data = TEST_OPTIONS.copy()
|
||||
assert subentry.data == expected_subentry_data
|
||||
|
||||
# Find the AI Task subentry
|
||||
ai_task_subentry = next(
|
||||
iter(
|
||||
entry
|
||||
for entry in mock_config_entry.subentries.values()
|
||||
if entry.subentry_type == "ai_task_data"
|
||||
)
|
||||
)
|
||||
assert ai_task_subentry.unique_id is None
|
||||
assert ai_task_subentry.title == "Ollama AI Task"
|
||||
assert ai_task_subentry.subentry_type == "ai_task_data"
|
||||
|
||||
migrated_entity = entity_registry.async_get(entity.entity_id)
|
||||
assert migrated_entity is not None
|
||||
assert migrated_entity.config_entry_id == mock_config_entry.entry_id
|
||||
@ -204,10 +223,17 @@ async def test_migration_from_v1_with_multiple_urls(
|
||||
|
||||
for idx, entry in enumerate(entries):
|
||||
assert entry.version == 3
|
||||
assert entry.minor_version == 1
|
||||
assert entry.minor_version == 2
|
||||
assert not entry.options
|
||||
assert len(entry.subentries) == 1
|
||||
subentry = list(entry.subentries.values())[0]
|
||||
assert len(entry.subentries) == 2
|
||||
|
||||
subentry = next(
|
||||
iter(
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "conversation"
|
||||
)
|
||||
)
|
||||
assert subentry.subentry_type == "conversation"
|
||||
# Subentry should include the model along with the original options
|
||||
expected_subentry_data = TEST_OPTIONS.copy()
|
||||
@ -215,6 +241,17 @@ async def test_migration_from_v1_with_multiple_urls(
|
||||
assert subentry.data == expected_subentry_data
|
||||
assert subentry.title == f"Ollama {idx + 1}"
|
||||
|
||||
# Find the AI Task subentry
|
||||
ai_task_subentry = next(
|
||||
iter(
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == "ai_task_data"
|
||||
)
|
||||
)
|
||||
assert ai_task_subentry.subentry_type == "ai_task_data"
|
||||
assert ai_task_subentry.title == "Ollama AI Task"
|
||||
|
||||
dev = device_registry.async_get_device(
|
||||
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
|
||||
)
|
||||
@ -295,9 +332,10 @@ async def test_migration_from_v1_with_same_urls(
|
||||
|
||||
entry = entries[0]
|
||||
assert entry.version == 3
|
||||
assert entry.minor_version == 1
|
||||
assert entry.minor_version == 2
|
||||
assert not entry.options
|
||||
assert len(entry.subentries) == 2 # Two subentries from the two original entries
|
||||
# Two conversation subentries from the two original entries and 1 aitask subentry
|
||||
assert len(entry.subentries) == 3
|
||||
|
||||
# Check both subentries exist with correct data
|
||||
subentries = list(entry.subentries.values())
|
||||
@ -305,7 +343,11 @@ async def test_migration_from_v1_with_same_urls(
|
||||
assert "Ollama" in titles
|
||||
assert "Ollama 2" in titles
|
||||
|
||||
for subentry in subentries:
|
||||
conversation_subentries = [
|
||||
subentry for subentry in subentries if subentry.subentry_type == "conversation"
|
||||
]
|
||||
assert len(conversation_subentries) == 2
|
||||
for subentry in conversation_subentries:
|
||||
assert subentry.subentry_type == "conversation"
|
||||
# Subentry should include the model along with the original options
|
||||
expected_subentry_data = TEST_OPTIONS.copy()
|
||||
@ -415,10 +457,10 @@ async def test_migration_from_v2_1(
|
||||
assert len(entries) == 1
|
||||
entry = entries[0]
|
||||
assert entry.version == 3
|
||||
assert entry.minor_version == 1
|
||||
assert entry.minor_version == 2
|
||||
assert not entry.options
|
||||
assert entry.title == "Ollama"
|
||||
assert len(entry.subentries) == 2
|
||||
assert len(entry.subentries) == 3
|
||||
conversation_subentries = [
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
@ -504,14 +546,44 @@ async def test_migration_from_v2_2(hass: HomeAssistant) -> None:
|
||||
|
||||
# Check migration to v3.1
|
||||
assert mock_config_entry.version == 3
|
||||
assert mock_config_entry.minor_version == 1
|
||||
assert mock_config_entry.minor_version == 2
|
||||
|
||||
# Check that model was moved from main data to subentry
|
||||
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
|
||||
assert len(mock_config_entry.subentries) == 1
|
||||
assert len(mock_config_entry.subentries) == 2
|
||||
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
assert subentry.data == {
|
||||
**V21_TEST_USER_DATA,
|
||||
ollama.CONF_MODEL: "test_model:latest",
|
||||
}
|
||||
|
||||
|
||||
async def test_migration_from_v3_1_without_subentry(hass: HomeAssistant) -> None:
|
||||
"""Test migration from version 3.1 where there is no existing subentry.
|
||||
|
||||
This exercises the code path where the model is not moved to a subentry
|
||||
because the subentry does not exist, which is a scenario that can happen
|
||||
if the user created the config entry without adding a subentry, or
|
||||
if the user manually removed the subentry after the migration to v3.1.
|
||||
"""
|
||||
mock_config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
ollama.CONF_MODEL: "test_model:latest",
|
||||
},
|
||||
version=3,
|
||||
minor_version=1,
|
||||
)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.ollama.async_setup_entry",
|
||||
return_value=True,
|
||||
):
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
|
||||
assert mock_config_entry.version == 3
|
||||
assert mock_config_entry.minor_version == 2
|
||||
|
||||
assert next(iter(mock_config_entry.subentries.values()), None) is None
|
||||
|
Loading…
x
Reference in New Issue
Block a user