mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 15:47:12 +00:00
Add attachment support in ollama ai task (#148981)
This commit is contained in:
parent
380c737901
commit
f90e06fde1
@ -39,7 +39,10 @@ class OllamaTaskEntity(
|
|||||||
):
|
):
|
||||||
"""Ollama AI Task entity."""
|
"""Ollama AI Task entity."""
|
||||||
|
|
||||||
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
|
_attr_supported_features = (
|
||||||
|
ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||||
|
| ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
|
||||||
|
)
|
||||||
|
|
||||||
async def _async_generate_data(
|
async def _async_generate_data(
|
||||||
self,
|
self,
|
||||||
|
@ -106,9 +106,18 @@ def _convert_content(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
if isinstance(chat_content, conversation.UserContent):
|
if isinstance(chat_content, conversation.UserContent):
|
||||||
|
images: list[ollama.Image] = []
|
||||||
|
for attachment in chat_content.attachments or ():
|
||||||
|
if not attachment.mime_type.startswith("image/"):
|
||||||
|
raise HomeAssistantError(
|
||||||
|
translation_domain=DOMAIN,
|
||||||
|
translation_key="unsupported_attachment_type",
|
||||||
|
)
|
||||||
|
images.append(ollama.Image(value=attachment.path))
|
||||||
return ollama.Message(
|
return ollama.Message(
|
||||||
role=MessageRole.USER.value,
|
role=MessageRole.USER.value,
|
||||||
content=chat_content.content,
|
content=chat_content.content,
|
||||||
|
images=images or None,
|
||||||
)
|
)
|
||||||
if isinstance(chat_content, conversation.SystemContent):
|
if isinstance(chat_content, conversation.SystemContent):
|
||||||
return ollama.Message(
|
return ollama.Message(
|
||||||
|
@ -94,5 +94,10 @@
|
|||||||
"download": "[%key:component::ollama::config_subentries::conversation::progress::download%]"
|
"download": "[%key:component::ollama::config_subentries::conversation::progress::download%]"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"exceptions": {
|
||||||
|
"unsupported_attachment_type": {
|
||||||
|
"message": "Ollama only supports image attachments in user content, but received non-image attachment."
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
"""Test AI Task platform of Ollama integration."""
|
"""Test AI Task platform of Ollama integration."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import ollama
|
||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import ai_task
|
from homeassistant.components import ai_task, media_source
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import entity_registry as er, selector
|
from homeassistant.helpers import entity_registry as er, selector
|
||||||
@ -243,3 +245,115 @@ async def test_generate_invalid_structured_data(
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
|
async def test_generate_data_with_attachment(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test AI Task data generation with image attachments."""
|
||||||
|
entity_id = "ai_task.ollama_ai_task"
|
||||||
|
|
||||||
|
# Mock the Ollama chat response as an async iterator
|
||||||
|
async def mock_chat_response():
|
||||||
|
"""Mock streaming response."""
|
||||||
|
yield {
|
||||||
|
"message": {"role": "assistant", "content": "Generated test data"},
|
||||||
|
"done": True,
|
||||||
|
"done_reason": "stop",
|
||||||
|
}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
|
side_effect=[
|
||||||
|
media_source.PlayMedia(
|
||||||
|
url="http://example.com/doorbell_snapshot.jpg",
|
||||||
|
mime_type="image/jpeg",
|
||||||
|
path=Path("doorbell_snapshot.jpg"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"ollama.AsyncClient.chat",
|
||||||
|
return_value=mock_chat_response(),
|
||||||
|
) as mock_chat,
|
||||||
|
):
|
||||||
|
result = await ai_task.async_generate_data(
|
||||||
|
hass,
|
||||||
|
task_name="Test Task",
|
||||||
|
entity_id=entity_id,
|
||||||
|
instructions="Generate test data",
|
||||||
|
attachments=[
|
||||||
|
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.data == "Generated test data"
|
||||||
|
|
||||||
|
assert mock_chat.call_count == 1
|
||||||
|
messages = mock_chat.call_args[1]["messages"]
|
||||||
|
assert len(messages) == 2
|
||||||
|
chat_message = messages[1]
|
||||||
|
assert chat_message.role == "user"
|
||||||
|
assert chat_message.content == "Generate test data"
|
||||||
|
assert chat_message.images == [
|
||||||
|
ollama.Image(value=Path("doorbell_snapshot.jpg")),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
|
async def test_generate_data_with_unsupported_file_format(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test AI Task data generation with image attachments."""
|
||||||
|
entity_id = "ai_task.ollama_ai_task"
|
||||||
|
|
||||||
|
# Mock the Ollama chat response as an async iterator
|
||||||
|
async def mock_chat_response():
|
||||||
|
"""Mock streaming response."""
|
||||||
|
yield {
|
||||||
|
"message": {"role": "assistant", "content": "Generated test data"},
|
||||||
|
"done": True,
|
||||||
|
"done_reason": "stop",
|
||||||
|
}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
|
side_effect=[
|
||||||
|
media_source.PlayMedia(
|
||||||
|
url="http://example.com/doorbell_snapshot.jpg",
|
||||||
|
mime_type="image/jpeg",
|
||||||
|
path=Path("doorbell_snapshot.jpg"),
|
||||||
|
),
|
||||||
|
media_source.PlayMedia(
|
||||||
|
url="http://example.com/context.txt",
|
||||||
|
mime_type="text/plain",
|
||||||
|
path=Path("context.txt"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"ollama.AsyncClient.chat",
|
||||||
|
return_value=mock_chat_response(),
|
||||||
|
),
|
||||||
|
pytest.raises(
|
||||||
|
HomeAssistantError,
|
||||||
|
match="Ollama only supports image attachments in user content",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
await ai_task.async_generate_data(
|
||||||
|
hass,
|
||||||
|
task_name="Test Task",
|
||||||
|
entity_id=entity_id,
|
||||||
|
instructions="Generate test data",
|
||||||
|
attachments=[
|
||||||
|
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
|
||||||
|
{"media_content_id": "media-source://media/context.txt"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user