"""Test AI Task platform of Ollama integration."""
from unittest.mock import patch
import pytest
import voluptuous as vol
from homeassistant.components import ai_task
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er, selector
from tests.common import MockConfigEntry
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_data(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task data generation."""
entity_id = "ai_task.ollama_ai_task"
# Ensure entity is linked to the subentry
entity_entry = entity_registry.async_get(entity_id)
ai_task_entry = next(
iter(
entry
for entry in mock_config_entry.subentries.values()
if entry.subentry_type == "ai_task_data"
)
)
assert entity_entry is not None
assert entity_entry.config_entry_id == mock_config_entry.entry_id
    assert entity_entry.config_subentry_id == ai_task_entry.subentry_id

# Mock the Ollama chat response as an async iterator
async def mock_chat_response():
"""Mock streaming response."""
yield {
"message": {"role": "assistant", "content": "Generated test data"},
"done": True,
"done_reason": "stop",
        }

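    # Generating data should consume the patched stream and return the
    # assistant message content as the task result.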
with patch(
"ollama.AsyncClient.chat",
return_value=mock_chat_response(),
):
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Generate test data",
)

    assert result.data == "Generated test data"


@pytest.mark.usefixtures("mock_init_component")
async def test_run_task_with_streaming(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task data generation with streaming response."""
    entity_id = "ai_task.ollama_ai_task"

async def mock_stream():
"""Mock streaming response."""
yield {"message": {"role": "assistant", "content": "Stream "}}
yield {
"message": {"role": "assistant", "content": "response"},
"done": True,
"done_reason": "stop",
        }

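    # The entity is expected to concatenate the streamed chunks into a
    # single result string.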
with patch(
"ollama.AsyncClient.chat",
return_value=mock_stream(),
):
result = await ai_task.async_generate_data(
hass,
task_name="Test Streaming Task",
entity_id=entity_id,
instructions="Generate streaming data",
)

    assert result.data == "Stream response"


@pytest.mark.usefixtures("mock_init_component")
async def test_run_task_connection_error(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task with connection error."""
    entity_id = "ai_task.ollama_ai_task"

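    # The bare exception raised by the client is expected to propagate
    # unchanged to the caller.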
with (
patch(
"ollama.AsyncClient.chat",
side_effect=Exception("Connection failed"),
),
pytest.raises(Exception, match="Connection failed"),
):
await ai_task.async_generate_data(
hass,
task_name="Test Error Task",
entity_id=entity_id,
instructions="Generate data that will fail",
        )


@pytest.mark.usefixtures("mock_init_component")
async def test_run_task_empty_response(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task with empty response."""
    entity_id = "ai_task.ollama_ai_task"

# Mock response with space (minimally non-empty)
async def mock_minimal_response():
"""Mock minimal streaming response."""
yield {
"message": {"role": "assistant", "content": " "},
"done": True,
"done_reason": "stop",
        }

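    # Whitespace-only content is still returned verbatim as the task result.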
with patch(
"ollama.AsyncClient.chat",
return_value=mock_minimal_response(),
):
result = await ai_task.async_generate_data(
hass,
task_name="Test Minimal Task",
entity_id=entity_id,
instructions="Generate minimal data",
)

    assert result.data == " "


@pytest.mark.usefixtures("mock_init_component")
async def test_generate_structured_data(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task data generation."""
entity_id = "ai_task.ollama_ai_task"
# Mock the Ollama chat response as an async iterator
async def mock_chat_response():
"""Mock streaming response."""
yield {
"message": {
"role": "assistant",
"content": '{"characters": ["Mario", "Luigi"]}',
},
"done": True,
"done_reason": "stop",
        }

with patch(
"ollama.AsyncClient.chat",
return_value=mock_chat_response(),
) as mock_chat:
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Generate test data",
structure=vol.Schema(
{
vol.Required("characters"): selector.selector(
{
"text": {
"multiple": True,
}
}
)
},
),
        )

assert result.data == {"characters": ["Mario", "Luigi"]}
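
    # The voluptuous schema should be translated to a JSON Schema and passed
    # to Ollama through the chat call's "format" argument.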
assert mock_chat.call_count == 1
assert mock_chat.call_args[1]["format"] == {
"type": "object",
"properties": {"characters": {"items": {"type": "string"}, "type": "array"}},
"required": ["characters"],
    }


@pytest.mark.usefixtures("mock_init_component")
async def test_generate_invalid_structured_data(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task data generation."""
entity_id = "ai_task.ollama_ai_task"
# Mock the Ollama chat response as an async iterator
async def mock_chat_response():
"""Mock streaming response."""
yield {
"message": {
"role": "assistant",
"content": "INVALID JSON RESPONSE",
},
"done": True,
"done_reason": "stop",
        }

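    # Content that cannot be parsed against the requested structure should
    # surface as a HomeAssistantError.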
with (
patch(
"ollama.AsyncClient.chat",
return_value=mock_chat_response(),
),
pytest.raises(HomeAssistantError),
):
await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Generate test data",
structure=vol.Schema(
{
vol.Required("characters"): selector.selector(
{
"text": {
"multiple": True,
}
}
)
},
),
)