Add attachment support in ollama ai task (#148981)

Author: Allen Porter, 2025-07-18 22:27:48 -07:00 (committed by GitHub)
Parent: 380c737901
Commit: f90e06fde1
4 changed files with 133 additions and 2 deletions

homeassistant/components/ollama/ai_task.py

@@ -39,7 +39,10 @@ class OllamaTaskEntity(
 ):
     """Ollama AI Task entity."""

-    _attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
+    _attr_supported_features = (
+        ai_task.AITaskEntityFeature.GENERATE_DATA
+        | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
+    )

     async def _async_generate_data(
         self,
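With SUPPORT_ATTACHMENTS advertised, the AI Task helper can pass resolved media attachments through to this entity. A minimal sketch of how a caller exercises the new capability, mirroring the test added later in this commit; the entity id, task name, instructions, and media_content_id are illustrative values, not part of the diff.

from homeassistant.components import ai_task
from homeassistant.core import HomeAssistant


async def describe_doorbell_snapshot(hass: HomeAssistant) -> str:
    """Ask the Ollama AI Task entity to describe a camera snapshot."""
    # All identifiers below are example values; swap in your own entity and media item.
    result = await ai_task.async_generate_data(
        hass,
        task_name="Describe doorbell",
        entity_id="ai_task.ollama_ai_task",
        instructions="Describe what is at the front door",
        attachments=[
            {"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
        ],
    )
    return result.data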

homeassistant/components/ollama/entity.py

@@ -106,9 +106,18 @@ def _convert_content(
             ],
         )
     if isinstance(chat_content, conversation.UserContent):
+        images: list[ollama.Image] = []
+        for attachment in chat_content.attachments or ():
+            if not attachment.mime_type.startswith("image/"):
+                raise HomeAssistantError(
+                    translation_domain=DOMAIN,
+                    translation_key="unsupported_attachment_type",
+                )
+            images.append(ollama.Image(value=attachment.path))
         return ollama.Message(
             role=MessageRole.USER.value,
             content=chat_content.content,
+            images=images or None,
         )
     if isinstance(chat_content, conversation.SystemContent):
         return ollama.Message(
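For a user turn, _convert_content now collects image attachments into ollama.Image objects built from their resolved paths, and raises HomeAssistantError with the new unsupported_attachment_type translation key for anything that is not an image. A rough sketch of the message this conversion produces for a single resolved image follows; the prompt text and file path are made up for illustration.

from pathlib import Path

import ollama

# Example of the converted user message once one image attachment is present.
# The content string and file path are illustrative, not taken from the diff.
message = ollama.Message(
    role="user",
    content="What is at the front door?",
    images=[ollama.Image(value=Path("/media/doorbell_snapshot.jpg"))],
)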

homeassistant/components/ollama/strings.json

@@ -94,5 +94,10 @@
         "download": "[%key:component::ollama::config_subentries::conversation::progress::download%]"
       }
     }
+  },
+  "exceptions": {
+    "unsupported_attachment_type": {
+      "message": "Ollama only supports image attachments in user content, but received non-image attachment."
+    }
   }
 }

tests/components/ollama/test_ai_task.py

@@ -1,11 +1,13 @@
 """Test AI Task platform of Ollama integration."""

+from pathlib import Path
 from unittest.mock import patch

+import ollama
 import pytest
 import voluptuous as vol

-from homeassistant.components import ai_task
+from homeassistant.components import ai_task, media_source
 from homeassistant.core import HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers import entity_registry as er, selector
@@ -243,3 +245,115 @@ async def test_generate_invalid_structured_data(
                 },
             ),
         )

@pytest.mark.usefixtures("mock_init_component")
async def test_generate_data_with_attachment(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task data generation with image attachments."""
entity_id = "ai_task.ollama_ai_task"
# Mock the Ollama chat response as an async iterator
async def mock_chat_response():
"""Mock streaming response."""
yield {
"message": {"role": "assistant", "content": "Generated test data"},
"done": True,
"done_reason": "stop",
}
with (
patch(
"homeassistant.components.media_source.async_resolve_media",
side_effect=[
media_source.PlayMedia(
url="http://example.com/doorbell_snapshot.jpg",
mime_type="image/jpeg",
path=Path("doorbell_snapshot.jpg"),
),
],
),
patch(
"ollama.AsyncClient.chat",
return_value=mock_chat_response(),
) as mock_chat,
):
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Generate test data",
attachments=[
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
],
)
assert result.data == "Generated test data"
assert mock_chat.call_count == 1
messages = mock_chat.call_args[1]["messages"]
assert len(messages) == 2
chat_message = messages[1]
assert chat_message.role == "user"
assert chat_message.content == "Generate test data"
assert chat_message.images == [
ollama.Image(value=Path("doorbell_snapshot.jpg")),
    ]


@pytest.mark.usefixtures("mock_init_component")
async def test_generate_data_with_unsupported_file_format(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task data generation with image attachments."""
entity_id = "ai_task.ollama_ai_task"
# Mock the Ollama chat response as an async iterator
async def mock_chat_response():
"""Mock streaming response."""
yield {
"message": {"role": "assistant", "content": "Generated test data"},
"done": True,
"done_reason": "stop",
}
with (
patch(
"homeassistant.components.media_source.async_resolve_media",
side_effect=[
media_source.PlayMedia(
url="http://example.com/doorbell_snapshot.jpg",
mime_type="image/jpeg",
path=Path("doorbell_snapshot.jpg"),
),
media_source.PlayMedia(
url="http://example.com/context.txt",
mime_type="text/plain",
path=Path("context.txt"),
),
],
),
patch(
"ollama.AsyncClient.chat",
return_value=mock_chat_response(),
),
pytest.raises(
HomeAssistantError,
match="Ollama only supports image attachments in user content",
),
):
await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Generate test data",
attachments=[
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
{"media_content_id": "media-source://media/context.txt"},
],
)