Files
core/tests/components/anthropic/test_ai_task.py
2025-11-10 14:01:28 +01:00

212 lines
6.5 KiB
Python

"""Tests for the Anthropic integration."""
from pathlib import Path
from unittest.mock import AsyncMock, patch
import pytest
import voluptuous as vol
from homeassistant.components import ai_task, media_source
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er, selector
from . import create_content_block, create_tool_use_block
from tests.common import MockConfigEntry
async def test_generate_data(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test plain-text data generation through the AI Task entity.

    Checks that the entity is registered against the config entry's
    ``ai_task_data`` subentry, then verifies that the text streamed by the
    mocked client is returned as the task result.
    """
    entity_id = "ai_task.claude_ai_task"

    # The entity must be tied to the config entry and its AI Task subentry.
    registry_entry = entity_registry.async_get(entity_id)
    matching_subentries = (
        subentry
        for subentry in mock_config_entry.subentries.values()
        if subentry.subentry_type == "ai_task_data"
    )
    ai_task_subentry = next(matching_subentries)

    assert registry_entry is not None
    assert registry_entry.config_entry_id == mock_config_entry.entry_id
    assert registry_entry.config_subentry_id == ai_task_subentry.subentry_id

    # Queue a single text content block as the mocked stream output.
    streamed_block = create_content_block(0, ["The test data"])
    mock_create_stream.return_value = [streamed_block]

    result = await ai_task.async_generate_data(
        hass,
        task_name="Test Task",
        entity_id=entity_id,
        instructions="Generate test data",
    )

    assert result.data == "The test data"
async def test_generate_structured_data(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
) -> None:
    """Test structured data generation from a streamed tool-use block.

    The mocked stream delivers the JSON payload split across several
    fragments; the task result must be the parsed dictionary.
    """
    # The JSON arrives in pieces, split mid-key and mid-value.
    json_fragments = ['{"charac', 'ters": ["Mario', '", "Luigi"]}']
    tool_use_block = create_tool_use_block(
        1,
        "toolu_0123456789AbCdEfGhIjKlM",
        "test_task",
        json_fragments,
    )
    mock_create_stream.return_value = [tool_use_block]

    # Requested structure: one required multi-value text field.
    characters_selector = selector.selector({"text": {"multiple": True}})
    requested_structure = vol.Schema(
        {vol.Required("characters"): characters_selector},
    )

    result = await ai_task.async_generate_data(
        hass,
        task_name="Test Task",
        entity_id="ai_task.claude_ai_task",
        instructions="Generate test data",
        structure=requested_structure,
    )

    assert result.data == {"characters": ["Mario", "Luigi"]}
async def test_generate_invalid_structured_data(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
) -> None:
    """Test AI Task with invalid JSON response.

    The mocked stream returns a tool-use block whose payload is not valid
    JSON; requesting structured data must then raise ``HomeAssistantError``
    with the "Error with Claude structured response" message.
    """
    mock_create_stream.return_value = [
        create_tool_use_block(
            1,
            "toolu_0123456789AbCdEfGhIjKlM",
            "test_task",
            # Passed as a one-element list of fragments, matching how the
            # other tests in this module call the helper. A bare string
            # would be iterated character by character.
            ["INVALID JSON RESPONSE"],
        )
    ]

    # Same structure as the valid case: one required multi-value text field.
    structure = vol.Schema(
        {
            vol.Required("characters"): selector.selector(
                {
                    "text": {
                        "multiple": True,
                    }
                }
            )
        },
    )

    with pytest.raises(
        HomeAssistantError, match="Error with Claude structured response"
    ):
        await ai_task.async_generate_data(
            hass,
            task_name="Test Task",
            entity_id="ai_task.claude_ai_task",
            instructions="Generate test data",
            structure=structure,
        )
async def test_generate_data_with_attachments(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    mock_create_stream: AsyncMock,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test AI Task data generation with attachments.

    Two media-source attachments (a JPEG and a PDF) are resolved through a
    patched ``async_resolve_media`` while file existence and file contents
    are faked. The messages handed to the mocked stream must then contain
    the text prompt followed by one image block and one document block.
    """
    entity_id = "ai_task.claude_ai_task"

    mock_create_stream.return_value = [create_content_block(0, ["Hi there!"])]

    # Test with attachments.
    # Only media resolution and filesystem access are patched. An earlier
    # patch of the OpenAI integration's ``guess_file_type`` was dropped: it
    # targets a module that is not under test here, and the expected payload
    # below keeps "application/pdf" for the PDF, so a forced "image/jpeg"
    # guess was never in effect.
    with (
        patch(
            "homeassistant.components.media_source.async_resolve_media",
            # One resolved media object per attachment, in request order.
            side_effect=[
                media_source.PlayMedia(
                    url="http://example.com/doorbell_snapshot.jpg",
                    mime_type="image/jpeg",
                    path=Path("doorbell_snapshot.jpg"),
                ),
                media_source.PlayMedia(
                    url="http://example.com/context.pdf",
                    mime_type="application/pdf",
                    path=Path("context.pdf"),
                ),
            ],
        ),
        patch("pathlib.Path.exists", return_value=True),
        patch("pathlib.Path.read_bytes", return_value=b"fake_image_data"),
    ):
        result = await ai_task.async_generate_data(
            hass,
            task_name="Test Task",
            entity_id=entity_id,
            instructions="Test prompt",
            attachments=[
                {"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
                {"media_content_id": "media-source://media/context.pdf"},
            ],
        )

    assert result.data == "Hi there!"

    # Verify that the create stream was called with the correct parameters
    call_args = mock_create_stream.call_args
    assert call_args is not None

    # Check that the input includes the attachments
    input_messages = call_args.kwargs["messages"]
    assert len(input_messages) > 0

    # NOTE(review): the user message carrying the attachments is expected at
    # index -2; presumably the last entry is the assistant reply added to the
    # same list after streaming - confirm against the entity implementation.
    user_message_with_attachments = input_messages[-2]
    assert user_message_with_attachments is not None
    assert isinstance(user_message_with_attachments["content"], list)
    assert len(user_message_with_attachments["content"]) == 3  # Text + attachments

    # "ZmFrZV9pbWFnZV9kYXRh" is the base64 encoding of b"fake_image_data",
    # the bytes returned by the patched Path.read_bytes for both files.
    assert user_message_with_attachments["content"] == [
        {"type": "text", "text": "Test prompt"},
        {
            "type": "image",
            "source": {
                "data": "ZmFrZV9pbWFnZV9kYXRh",
                "media_type": "image/jpeg",
                "type": "base64",
            },
        },
        {
            "type": "document",
            "source": {
                "data": "ZmFrZV9pbWFnZV9kYXRh",
                "media_type": "application/pdf",
                "type": "base64",
            },
        },
    ]