"""Tests for the Anthropic integration.""" from pathlib import Path from unittest.mock import AsyncMock, patch import pytest import voluptuous as vol from homeassistant.components import ai_task, media_source from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import entity_registry as er, selector from . import create_content_block, create_tool_use_block from tests.common import MockConfigEntry async def test_generate_data( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component, mock_create_stream: AsyncMock, entity_registry: er.EntityRegistry, ) -> None: """Test AI Task data generation.""" entity_id = "ai_task.claude_ai_task" # Ensure entity is linked to the subentry entity_entry = entity_registry.async_get(entity_id) ai_task_entry = next( iter( entry for entry in mock_config_entry.subentries.values() if entry.subentry_type == "ai_task_data" ) ) assert entity_entry is not None assert entity_entry.config_entry_id == mock_config_entry.entry_id assert entity_entry.config_subentry_id == ai_task_entry.subentry_id mock_create_stream.return_value = [create_content_block(0, ["The test data"])] result = await ai_task.async_generate_data( hass, task_name="Test Task", entity_id=entity_id, instructions="Generate test data", ) assert result.data == "The test data" async def test_generate_structured_data( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component, mock_create_stream: AsyncMock, ) -> None: """Test AI Task structured data generation.""" mock_create_stream.return_value = [ create_tool_use_block( 1, "toolu_0123456789AbCdEfGhIjKlM", "test_task", ['{"charac', 'ters": ["Mario', '", "Luigi"]}'], ), ] result = await ai_task.async_generate_data( hass, task_name="Test Task", entity_id="ai_task.claude_ai_task", instructions="Generate test data", structure=vol.Schema( { vol.Required("characters"): selector.selector( { "text": { "multiple": True, } } ) }, ), ) assert result.data == {"characters": ["Mario", "Luigi"]} async def test_generate_invalid_structured_data( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component, mock_create_stream: AsyncMock, ) -> None: """Test AI Task with invalid JSON response.""" mock_create_stream.return_value = [ create_tool_use_block( 1, "toolu_0123456789AbCdEfGhIjKlM", "test_task", "INVALID JSON RESPONSE", ) ] with pytest.raises( HomeAssistantError, match="Error with Claude structured response" ): await ai_task.async_generate_data( hass, task_name="Test Task", entity_id="ai_task.claude_ai_task", instructions="Generate test data", structure=vol.Schema( { vol.Required("characters"): selector.selector( { "text": { "multiple": True, } } ) }, ), ) async def test_generate_data_with_attachments( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component, mock_create_stream: AsyncMock, entity_registry: er.EntityRegistry, ) -> None: """Test AI Task data generation with attachments.""" entity_id = "ai_task.claude_ai_task" mock_create_stream.return_value = [create_content_block(0, ["Hi there!"])] # Test with attachments with ( patch( "homeassistant.components.media_source.async_resolve_media", side_effect=[ media_source.PlayMedia( url="http://example.com/doorbell_snapshot.jpg", mime_type="image/jpeg", path=Path("doorbell_snapshot.jpg"), ), media_source.PlayMedia( url="http://example.com/context.pdf", mime_type="application/pdf", path=Path("context.pdf"), ), ], ), patch("pathlib.Path.exists", return_value=True), patch( "homeassistant.components.openai_conversation.entity.guess_file_type", return_value=("image/jpeg", None), ), patch("pathlib.Path.read_bytes", return_value=b"fake_image_data"), ): result = await ai_task.async_generate_data( hass, task_name="Test Task", entity_id=entity_id, instructions="Test prompt", attachments=[ {"media_content_id": "media-source://media/doorbell_snapshot.jpg"}, {"media_content_id": "media-source://media/context.pdf"}, ], ) assert result.data == "Hi there!" # Verify that the create stream was called with the correct parameters # The last call should have the user message with attachments call_args = mock_create_stream.call_args assert call_args is not None # Check that the input includes the attachments input_messages = call_args[1]["messages"] assert len(input_messages) > 0 # Find the user message with attachments user_message_with_attachments = input_messages[-2] assert user_message_with_attachments is not None assert isinstance(user_message_with_attachments["content"], list) assert len(user_message_with_attachments["content"]) == 3 # Text + attachments assert user_message_with_attachments["content"] == [ {"type": "text", "text": "Test prompt"}, { "type": "image", "source": { "data": "ZmFrZV9pbWFnZV9kYXRh", "media_type": "image/jpeg", "type": "base64", }, }, { "type": "document", "source": { "data": "ZmFrZV9pbWFnZV9kYXRh", "media_type": "application/pdf", "type": "base64", }, }, ]