mirror of
https://github.com/home-assistant/core.git
synced 2025-11-13 13:00:11 +00:00
212 lines
6.5 KiB
Python
212 lines
6.5 KiB
Python
"""Tests for the Anthropic integration."""
|
|
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.components import ai_task, media_source
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
from homeassistant.helpers import entity_registry as er, selector
|
|
|
|
from . import create_content_block, create_tool_use_block
|
|
|
|
from tests.common import MockConfigEntry
|
|
|
|
|
|
async def test_generate_data(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test that a plain AI Task request returns the mocked text.

    Also checks that the AI Task entity is registered against the config
    entry and against its ``ai_task_data`` subentry.
    """
    task_entity_id = "ai_task.claude_ai_task"

    # The entity has to be registered under the entry's AI Task subentry.
    registry_entry = entity_registry.async_get(task_entity_id)
    task_subentry = next(
        candidate
        for candidate in mock_config_entry.subentries.values()
        if candidate.subentry_type == "ai_task_data"
    )
    assert registry_entry is not None
    assert registry_entry.config_entry_id == mock_config_entry.entry_id
    assert registry_entry.config_subentry_id == task_subentry.subentry_id

    # Single text content block handed back by the mocked stream factory.
    text_block = create_content_block(0, ["The test data"])
    mock_create_stream.return_value = [text_block]

    result = await ai_task.async_generate_data(
        hass,
        task_name="Test Task",
        entity_id=task_entity_id,
        instructions="Generate test data",
    )

    assert result.data == "The test data"
|
|
|
|
|
|
async def test_generate_structured_data(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
) -> None:
    """Test that a tool-use response is returned as structured data.

    The mocked tool-use block carries its JSON payload split across three
    fragments; the result is expected to be the decoded mapping.
    """
    # JSON payload deliberately cut mid-key and mid-value.
    json_fragments = ['{"charac', 'ters": ["Mario', '", "Luigi"]}']
    mock_create_stream.return_value = [
        create_tool_use_block(
            1,
            "toolu_0123456789AbCdEfGhIjKlM",
            "test_task",
            json_fragments,
        ),
    ]

    # Requested structure: one required multi-value text field.
    characters_selector = selector.selector({"text": {"multiple": True}})
    requested_structure = vol.Schema(
        {vol.Required("characters"): characters_selector},
    )

    result = await ai_task.async_generate_data(
        hass,
        task_name="Test Task",
        entity_id="ai_task.claude_ai_task",
        instructions="Generate test data",
        structure=requested_structure,
    )

    assert result.data == {"characters": ["Mario", "Luigi"]}
|
|
|
|
|
|
async def test_generate_invalid_structured_data(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
) -> None:
    """Test that a non-JSON tool-use payload raises HomeAssistantError."""
    # Tool-use block whose payload is not valid JSON.
    broken_block = create_tool_use_block(
        1,
        "toolu_0123456789AbCdEfGhIjKlM",
        "test_task",
        "INVALID JSON RESPONSE",
    )
    mock_create_stream.return_value = [broken_block]

    # Requested structure: one required multi-value text field.
    characters_selector = selector.selector({"text": {"multiple": True}})
    requested_structure = vol.Schema(
        {vol.Required("characters"): characters_selector},
    )

    expected_error = "Error with Claude structured response"
    with pytest.raises(HomeAssistantError, match=expected_error):
        await ai_task.async_generate_data(
            hass,
            task_name="Test Task",
            entity_id="ai_task.claude_ai_task",
            instructions="Generate test data",
            structure=requested_structure,
        )
|
|
|
|
|
|
async def test_generate_data_with_attachments(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test AI Task data generation with attachments.

    Two media-source attachments (a JPEG and a PDF) are resolved through a
    patched ``async_resolve_media``; file access is patched so no real files
    are needed. The test then checks that the user message passed to the
    mocked stream factory holds the prompt text followed by one base64
    ``image`` block and one base64 ``document`` block.
    """
    entity_id = "ai_task.claude_ai_task"

    mock_create_stream.return_value = [create_content_block(0, ["Hi there!"])]

    # Resolved media, returned in the order the attachments are listed.
    resolved_media = [
        media_source.PlayMedia(
            url="http://example.com/doorbell_snapshot.jpg",
            mime_type="image/jpeg",
            path=Path("doorbell_snapshot.jpg"),
        ),
        media_source.PlayMedia(
            url="http://example.com/context.pdf",
            mime_type="application/pdf",
            path=Path("context.pdf"),
        ),
    ]

    # NOTE: no MIME-guessing patch here. Each PlayMedia above carries an
    # explicit mime_type, and patching another integration's module
    # (openai_conversation) cannot influence the Anthropic code under test.
    with (
        patch(
            "homeassistant.components.media_source.async_resolve_media",
            side_effect=resolved_media,
        ),
        patch("pathlib.Path.exists", return_value=True),
        patch("pathlib.Path.read_bytes", return_value=b"fake_image_data"),
    ):
        result = await ai_task.async_generate_data(
            hass,
            task_name="Test Task",
            entity_id=entity_id,
            instructions="Test prompt",
            attachments=[
                {"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
                {"media_content_id": "media-source://media/context.pdf"},
            ],
        )

    assert result.data == "Hi there!"

    # Verify that the create stream was called with the correct parameters
    call_args = mock_create_stream.call_args
    assert call_args is not None

    # Check that the input includes the attachments
    input_messages = call_args.kwargs["messages"]
    # The user message is read from the second-to-last slot, so at least two
    # messages must be present.
    # NOTE(review): presumably the last entry is the assistant reply added
    # to the same list after the call - confirm against the entity code.
    assert len(input_messages) >= 2

    user_message_with_attachments = input_messages[-2]

    assert isinstance(user_message_with_attachments["content"], list)
    assert len(user_message_with_attachments["content"]) == 3  # Text + attachments
    # "ZmFrZV9pbWFnZV9kYXRh" is the base64 encoding of b"fake_image_data".
    assert user_message_with_attachments["content"] == [
        {"type": "text", "text": "Test prompt"},
        {
            "type": "image",
            "source": {
                "data": "ZmFrZV9pbWFnZV9kYXRh",
                "media_type": "image/jpeg",
                "type": "base64",
            },
        },
        {
            "type": "document",
            "source": {
                "data": "ZmFrZV9pbWFnZV9kYXRh",
                "media_type": "application/pdf",
                "type": "base64",
            },
        },
    ]
|