"""Test initialization of the AI Task component.""" from typing import Any from unittest.mock import patch from freezegun.api import FrozenDateTimeFactory import pytest import voluptuous as vol from homeassistant.components import media_source from homeassistant.components.ai_task import AITaskPreferences from homeassistant.components.ai_task.const import DATA_PREFERENCES from homeassistant.core import HomeAssistant from homeassistant.helpers import selector from .conftest import TEST_ENTITY_ID, MockAITaskEntity from tests.common import flush_store async def test_preferences_storage_load( hass: HomeAssistant, ) -> None: """Test that AITaskPreferences are stored and loaded correctly.""" preferences = AITaskPreferences(hass) await preferences.async_load() # Initial state should be None for entity IDs for key in AITaskPreferences.KEYS: assert getattr(preferences, key) is None, f"Initial {key} should be None" new_values = {key: f"ai_task.test_{key}" for key in AITaskPreferences.KEYS} preferences.async_set_preferences(**new_values) # Verify that current preferences object is updated for key, value in new_values.items(): assert getattr(preferences, key) == value, ( f"Current {key} should match set value" ) await flush_store(preferences._store) # Create a new preferences instance to test loading from store new_preferences_instance = AITaskPreferences(hass) await new_preferences_instance.async_load() for key in AITaskPreferences.KEYS: assert getattr(preferences, key) == getattr(new_preferences_instance, key), ( f"Loaded {key} should match saved value" ) @pytest.mark.parametrize( ("set_preferences", "msg_extra"), [ ( {"gen_data_entity_id": TEST_ENTITY_ID}, {}, ), ( {}, { "entity_id": TEST_ENTITY_ID, "attachments": [ { "media_content_id": "media-source://mock/blah_blah_blah.mp4", "media_content_type": "video/mp4", } ], }, ), ], ) async def test_generate_data_service( hass: HomeAssistant, init_components: None, freezer: FrozenDateTimeFactory, set_preferences: dict[str, str | None], msg_extra: dict[str, str], mock_ai_task_entity: MockAITaskEntity, ) -> None: """Test the generate data service.""" preferences = hass.data[DATA_PREFERENCES] preferences.async_set_preferences(**set_preferences) with patch( "homeassistant.components.media_source.async_resolve_media", return_value=media_source.PlayMedia( url="http://example.com/media.mp4", mime_type="video/mp4", ), ): result = await hass.services.async_call( "ai_task", "generate_data", { "task_name": "Test Name", "instructions": "Test prompt", } | msg_extra, blocking=True, return_response=True, ) assert result["data"] == "Mock result" assert len(mock_ai_task_entity.mock_generate_data_tasks) == 1 task = mock_ai_task_entity.mock_generate_data_tasks[0] assert len(task.attachments or []) == len( msg_attachments := msg_extra.get("attachments", []) ) for msg_attachment, attachment in zip( msg_attachments, task.attachments or [], strict=False ): assert attachment.url == "http://example.com/media.mp4" assert attachment.mime_type == "video/mp4" assert attachment.media_content_id == msg_attachment["media_content_id"] assert ( str(attachment) == f"" ) async def test_generate_data_service_structure_fields( hass: HomeAssistant, init_components: None, mock_ai_task_entity: MockAITaskEntity, ) -> None: """Test the entity can generate structured data with a top level object schema.""" result = await hass.services.async_call( "ai_task", "generate_data", { "task_name": "Profile Generation", "instructions": "Please generate a profile for a new user", "entity_id": TEST_ENTITY_ID, "structure": { "name": { "description": "First and last name of the user such as Alice Smith", "required": True, "selector": {"text": {}}, }, "age": { "description": "Age of the user", "selector": { "number": { "min": 0, "max": 120, } }, }, }, }, blocking=True, return_response=True, ) # Arbitrary data returned by the mock entity (not determined by above schema in test) assert result["data"] == { "name": "Tracy Chen", "age": 30, } assert mock_ai_task_entity.mock_generate_data_tasks task = mock_ai_task_entity.mock_generate_data_tasks[0] assert task.instructions == "Please generate a profile for a new user" assert task.structure assert isinstance(task.structure, vol.Schema) schema = list(task.structure.schema.items()) assert len(schema) == 2 name_key, name_value = schema[0] assert name_key == "name" assert isinstance(name_key, vol.Required) assert name_key.description == "First and last name of the user such as Alice Smith" assert isinstance(name_value, selector.TextSelector) age_key, age_value = schema[1] assert age_key == "age" assert isinstance(age_key, vol.Optional) assert age_key.description == "Age of the user" assert isinstance(age_value, selector.NumberSelector) assert age_value.config["min"] == 0 assert age_value.config["max"] == 120 @pytest.mark.parametrize( ("structure", "expected_exception", "expected_error"), [ ( { "name": { "description": "First and last name of the user such as Alice Smith", "selector": {"invalid-selector": {}}, }, }, vol.Invalid, r"Unknown selector type invalid-selector.*", ), ( { "name": { "description": "First and last name of the user such as Alice Smith", "selector": { "text": { "extra-config": False, } }, }, }, vol.Invalid, r"extra keys not allowed.*", ), ( { "name": { "description": "First and last name of the user such as Alice Smith", }, }, vol.Invalid, r"required key not provided.*selector.*", ), (12345, vol.Invalid, r"xpected a dictionary.*"), ("name", vol.Invalid, r"xpected a dictionary.*"), (["name"], vol.Invalid, r"xpected a dictionary.*"), ( { "name": { "description": "First and last name of the user such as Alice Smith", "selector": {"text": {}}, "extra-fields": "Some extra fields", }, }, vol.Invalid, r"extra keys not allowed .*", ), ( { "name": { "description": "First and last name of the user such as Alice Smith", "selector": "invalid-schema", }, }, vol.Invalid, r"xpected a dictionary for dictionary.", ), ], ids=( "invalid-selector", "invalid-selector-config", "missing-selector", "structure-is-int-not-object", "structure-is-str-not-object", "structure-is-list-not-object", "extra-fields", "invalid-selector-schema", ), ) async def test_generate_data_service_invalid_structure( hass: HomeAssistant, init_components: None, structure: Any, expected_exception: Exception, expected_error: str, ) -> None: """Test the entity can generate structured data.""" with pytest.raises(expected_exception, match=expected_error): await hass.services.async_call( "ai_task", "generate_data", { "task_name": "Profile Generation", "instructions": "Please generate a profile for a new user", "entity_id": TEST_ENTITY_ID, "structure": structure, }, blocking=True, return_response=True, )