diff --git a/homeassistant/components/google_generative_ai_conversation/ai_task.py b/homeassistant/components/google_generative_ai_conversation/ai_task.py index ab34af71ebe..b4f9d73e38d 100644 --- a/homeassistant/components/google_generative_ai_conversation/ai_task.py +++ b/homeassistant/components/google_generative_ai_conversation/ai_task.py @@ -2,11 +2,14 @@ from __future__ import annotations +from json import JSONDecodeError + from homeassistant.components import ai_task, conversation from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback +from homeassistant.util.json import json_loads from .const import LOGGER from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity @@ -42,7 +45,7 @@ class GoogleGenerativeAITaskEntity( chat_log: conversation.ChatLog, ) -> ai_task.GenDataTaskResult: """Handle a generate data task.""" - await self._async_handle_chat_log(chat_log) + await self._async_handle_chat_log(chat_log, task.structure) if not isinstance(chat_log.content[-1], conversation.AssistantContent): LOGGER.error( @@ -51,7 +54,25 @@ class GoogleGenerativeAITaskEntity( ) raise HomeAssistantError(ERROR_GETTING_RESPONSE) + text = chat_log.content[-1].content or "" + + if not task.structure: + return ai_task.GenDataTaskResult( + conversation_id=chat_log.conversation_id, + data=text, + ) + + try: + data = json_loads(text) + except JSONDecodeError as err: + LOGGER.error( + "Failed to parse JSON response: %s. Response: %s", + err, + text, + ) + raise HomeAssistantError(ERROR_GETTING_RESPONSE) from err + return ai_task.GenDataTaskResult( conversation_id=chat_log.conversation_id, - data=chat_log.content[-1].content or "", + data=data, ) diff --git a/homeassistant/components/google_generative_ai_conversation/entity.py b/homeassistant/components/google_generative_ai_conversation/entity.py index dea875212ef..d471da36a8c 100644 --- a/homeassistant/components/google_generative_ai_conversation/entity.py +++ b/homeassistant/components/google_generative_ai_conversation/entity.py @@ -21,6 +21,7 @@ from google.genai.types import ( Schema, Tool, ) +import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import conversation @@ -324,6 +325,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity): async def _async_handle_chat_log( self, chat_log: conversation.ChatLog, + structure: vol.Schema | None = None, ) -> None: """Generate an answer for the chat log.""" options = self.subentry.data @@ -402,6 +404,18 @@ class GoogleGenerativeAILLMBaseEntity(Entity): generateContentConfig.automatic_function_calling = ( AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None) ) + if structure: + generateContentConfig.response_mime_type = "application/json" + generateContentConfig.response_schema = _format_schema( + convert( + structure, + custom_serializer=( + chat_log.llm_api.custom_serializer + if chat_log.llm_api + else llm.selector_serializer + ), + ) + ) if not supports_system_instruction: messages = [ diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index a8a598c79f8..b239ad99119 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -458,7 +458,7 @@ class AssistAPI(API): api_prompt=self._async_get_api_prompt(llm_context, exposed_entities), llm_context=llm_context, tools=self._async_get_tools(llm_context, exposed_entities), - custom_serializer=_selector_serializer, + custom_serializer=selector_serializer, ) @callback @@ -701,7 +701,7 @@ def _get_exposed_entities( return data -def _selector_serializer(schema: Any) -> Any: # noqa: C901 +def selector_serializer(schema: Any) -> Any: # noqa: C901 """Convert selectors into OpenAPI schema.""" if not isinstance(schema, selector.Selector): return UNSUPPORTED @@ -782,7 +782,7 @@ def _selector_serializer(schema: Any) -> Any: # noqa: C901 result["properties"] = { field: convert( selector.selector(field_schema["selector"]), - custom_serializer=_selector_serializer, + custom_serializer=selector_serializer, ) for field, field_schema in fields.items() } diff --git a/tests/components/google_generative_ai_conversation/conftest.py b/tests/components/google_generative_ai_conversation/conftest.py index 244ac518fbd..da5976f46c4 100644 --- a/tests/components/google_generative_ai_conversation/conftest.py +++ b/tests/components/google_generative_ai_conversation/conftest.py @@ -112,19 +112,26 @@ async def setup_ha(hass: HomeAssistant) -> None: @pytest.fixture -def mock_send_message_stream() -> Generator[AsyncMock]: +def mock_chat_create() -> Generator[AsyncMock]: """Mock stream response.""" async def mock_generator(stream): for value in stream: yield value - with patch( - "google.genai.chats.AsyncChat.send_message_stream", - AsyncMock(), - ) as mock_send_message_stream: - mock_send_message_stream.side_effect = lambda **kwargs: mock_generator( - mock_send_message_stream.return_value.pop(0) - ) + mock_send_message_stream = AsyncMock() + mock_send_message_stream.side_effect = lambda **kwargs: mock_generator( + mock_send_message_stream.return_value.pop(0) + ) - yield mock_send_message_stream + with patch( + "google.genai.chats.AsyncChats.create", + return_value=AsyncMock(send_message_stream=mock_send_message_stream), + ) as mock_create: + yield mock_create + + +@pytest.fixture +def mock_send_message_stream(mock_chat_create) -> Generator[AsyncMock]: + """Mock stream response.""" + return mock_chat_create.return_value.send_message_stream diff --git a/tests/components/google_generative_ai_conversation/test_ai_task.py b/tests/components/google_generative_ai_conversation/test_ai_task.py index 72b62b64615..b2b44aa1cd6 100644 --- a/tests/components/google_generative_ai_conversation/test_ai_task.py +++ b/tests/components/google_generative_ai_conversation/test_ai_task.py @@ -4,10 +4,12 @@ from unittest.mock import AsyncMock from google.genai.types import GenerateContentResponse import pytest +import voluptuous as vol from homeassistant.components import ai_task from homeassistant.core import HomeAssistant -from homeassistant.helpers import entity_registry as er +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import entity_registry as er, selector from tests.common import MockConfigEntry from tests.components.conversation import ( @@ -17,14 +19,15 @@ from tests.components.conversation import ( @pytest.mark.usefixtures("mock_init_component") -async def test_run_task( +async def test_generate_data( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_chat_log: MockChatLog, # noqa: F811 mock_send_message_stream: AsyncMock, + mock_chat_create: AsyncMock, entity_registry: er.EntityRegistry, ) -> None: - """Test empty response.""" + """Test generating data.""" entity_id = "ai_task.google_ai_task" # Ensure it's linked to the subentry @@ -60,3 +63,68 @@ async def test_run_task( instructions="Test prompt", ) assert result.data == "Hi there!" + + mock_send_message_stream.return_value = [ + [ + GenerateContentResponse( + candidates=[ + { + "content": { + "parts": [{"text": '{"characters": ["Mario", "Luigi"]}'}], + "role": "model", + }, + } + ], + ), + ], + ] + result = await ai_task.async_generate_data( + hass, + task_name="Test Task", + entity_id=entity_id, + instructions="Give me 2 mario characters", + structure=vol.Schema( + { + vol.Required("characters"): selector.selector( + { + "text": { + "multiple": True, + } + } + ) + }, + ), + ) + assert result.data == {"characters": ["Mario", "Luigi"]} + + assert len(mock_chat_create.mock_calls) == 2 + config = mock_chat_create.mock_calls[-1][2]["config"] + assert config.response_mime_type == "application/json" + assert config.response_schema == { + "properties": {"characters": {"items": {"type": "STRING"}, "type": "ARRAY"}}, + "required": ["characters"], + "type": "OBJECT", + } + # Raise error on invalid JSON response + mock_send_message_stream.return_value = [ + [ + GenerateContentResponse( + candidates=[ + { + "content": { + "parts": [{"text": "INVALID JSON RESPONSE"}], + "role": "model", + }, + } + ], + ), + ], + ] + with pytest.raises(HomeAssistantError): + result = await ai_task.async_generate_data( + hass, + task_name="Test Task", + entity_id=entity_id, + instructions="Test prompt", + structure=vol.Schema({vol.Required("bla"): str}), + )