mirror of
https://github.com/home-assistant/core.git
synced 2025-11-09 19:09:32 +00:00
246 lines
7.1 KiB
Python
246 lines
7.1 KiB
Python
"""Test AI Task platform of Ollama integration."""
|
|
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.components import ai_task
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
from homeassistant.helpers import entity_registry as er, selector
|
|
|
|
from tests.common import MockConfigEntry
|
|
|
|
|
|
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_data(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test AI Task data generation.

    Also verifies that the AI Task entity is registered against the
    ``ai_task_data`` subentry of the mocked config entry.
    """
    entity_id = "ai_task.ollama_ai_task"

    # Ensure entity is linked to the subentry
    entity_entry = entity_registry.async_get(entity_id)
    # A generator expression is already an iterator, so ``next`` can consume
    # it directly; raises StopIteration if no ai_task_data subentry exists.
    ai_task_entry = next(
        entry
        for entry in mock_config_entry.subentries.values()
        if entry.subentry_type == "ai_task_data"
    )
    assert entity_entry is not None
    assert entity_entry.config_entry_id == mock_config_entry.entry_id
    assert entity_entry.config_subentry_id == ai_task_entry.subentry_id

    # Mock the Ollama chat response as an async generator yielding a single
    # final chunk (``done`` set), mimicking a streamed reply.
    async def mock_chat_response():
        """Mock streaming response."""
        yield {
            "message": {"role": "assistant", "content": "Generated test data"},
            "done": True,
            "done_reason": "stop",
        }

    with patch(
        "ollama.AsyncClient.chat",
        return_value=mock_chat_response(),
    ):
        result = await ai_task.async_generate_data(
            hass,
            task_name="Test Task",
            entity_id=entity_id,
            instructions="Generate test data",
        )

    # Without a structure, the streamed text is returned as plain data.
    assert result.data == "Generated test data"
|
|
|
|
|
|
@pytest.mark.usefixtures("mock_init_component")
async def test_run_task_with_streaming(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test AI Task data generation with streaming response."""
    target_entity = "ai_task.ollama_ai_task"

    # Two chunks: a partial one, then the final one carrying ``done``.
    canned_chunks = [
        {"message": {"role": "assistant", "content": "Stream "}},
        {
            "message": {"role": "assistant", "content": "response"},
            "done": True,
            "done_reason": "stop",
        },
    ]

    async def fake_stream():
        """Emit the canned chunks one by one, like a streamed chat reply."""
        for chunk in canned_chunks:
            yield chunk

    with patch("ollama.AsyncClient.chat", return_value=fake_stream()):
        outcome = await ai_task.async_generate_data(
            hass,
            task_name="Test Streaming Task",
            entity_id=target_entity,
            instructions="Generate streaming data",
        )

    # The chunk contents must be concatenated in order.
    assert outcome.data == "Stream response"
|
|
|
|
|
|
@pytest.mark.usefixtures("mock_init_component")
async def test_run_task_connection_error(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test AI Task with connection error."""
    target_entity = "ai_task.ollama_ai_task"
    # Raised by the patched client call; expected to surface to the caller.
    simulated_failure = Exception("Connection failed")

    with (
        patch("ollama.AsyncClient.chat", side_effect=simulated_failure),
        pytest.raises(Exception, match="Connection failed"),
    ):
        await ai_task.async_generate_data(
            hass,
            task_name="Test Error Task",
            entity_id=target_entity,
            instructions="Generate data that will fail",
        )
|
|
|
|
|
|
@pytest.mark.usefixtures("mock_init_component")
async def test_run_task_empty_response(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test AI Task with a minimal, whitespace-only response.

    The model replies with a single space; the result data must be that
    exact string (not stripped to empty).
    """
    entity_id = "ai_task.ollama_ai_task"

    # Mock response with space (minimally non-empty)
    async def mock_minimal_response():
        """Mock minimal streaming response."""
        yield {
            "message": {"role": "assistant", "content": " "},
            "done": True,
            "done_reason": "stop",
        }

    with patch(
        "ollama.AsyncClient.chat",
        return_value=mock_minimal_response(),
    ):
        result = await ai_task.async_generate_data(
            hass,
            task_name="Test Minimal Task",
            entity_id=entity_id,
            instructions="Generate minimal data",
        )

    # Whitespace content is passed through verbatim.
    assert result.data == " "
|
|
|
|
|
|
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_structured_data(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test AI Task structured data generation.

    Verifies that a JSON reply is parsed into the result data, and that the
    voluptuous ``structure`` is sent to Ollama as a JSON schema in the
    ``format`` keyword argument of ``chat``.
    """
    entity_id = "ai_task.ollama_ai_task"

    # Mock the Ollama chat response as an async generator whose content is
    # a JSON document matching the requested structure.
    async def mock_chat_response():
        """Mock streaming response."""
        yield {
            "message": {
                "role": "assistant",
                "content": '{"characters": ["Mario", "Luigi"]}',
            },
            "done": True,
            "done_reason": "stop",
        }

    with patch(
        "ollama.AsyncClient.chat",
        return_value=mock_chat_response(),
    ) as mock_chat:
        result = await ai_task.async_generate_data(
            hass,
            task_name="Test Task",
            entity_id=entity_id,
            instructions="Generate test data",
            structure=vol.Schema(
                {
                    vol.Required("characters"): selector.selector(
                        {
                            "text": {
                                "multiple": True,
                            }
                        }
                    )
                },
            ),
        )
        assert result.data == {"characters": ["Mario", "Luigi"]}

    assert mock_chat.call_count == 1
    # A multiple-text selector should translate to an array of strings.
    assert mock_chat.call_args.kwargs["format"] == {
        "type": "object",
        "properties": {"characters": {"items": {"type": "string"}, "type": "array"}},
        "required": ["characters"],
    }
|
|
|
|
|
|
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_invalid_structured_data(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test AI Task structured generation with a non-JSON reply.

    When a ``structure`` is requested but the model returns text that is not
    valid JSON, ``async_generate_data`` must raise ``HomeAssistantError``.
    """
    entity_id = "ai_task.ollama_ai_task"

    # Mock the Ollama chat response as an async generator whose content
    # cannot be parsed as JSON.
    async def mock_chat_response():
        """Mock streaming response."""
        yield {
            "message": {
                "role": "assistant",
                "content": "INVALID JSON RESPONSE",
            },
            "done": True,
            "done_reason": "stop",
        }

    with (
        patch(
            "ollama.AsyncClient.chat",
            return_value=mock_chat_response(),
        ),
        pytest.raises(HomeAssistantError),
    ):
        await ai_task.async_generate_data(
            hass,
            task_name="Test Task",
            entity_id=entity_id,
            instructions="Generate test data",
            structure=vol.Schema(
                {
                    vol.Required("characters"): selector.selector(
                        {
                            "text": {
                                "multiple": True,
                            }
                        }
                    )
                },
            ),
        )
|