Add Google Gen AI structured data support (#148143)

This commit is contained in:
Paulus Schoutsen 2025-07-05 17:22:17 +02:00 committed by GitHub
parent 33d05d99eb
commit 4f4ec6f41a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 127 additions and 17 deletions

View File

@@ -2,11 +2,14 @@
from __future__ import annotations from __future__ import annotations
from json import JSONDecodeError
from homeassistant.components import ai_task, conversation from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
from .const import LOGGER from .const import LOGGER
from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity
@@ -42,7 +45,7 @@ class GoogleGenerativeAITaskEntity(
chat_log: conversation.ChatLog, chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult: ) -> ai_task.GenDataTaskResult:
"""Handle a generate data task.""" """Handle a generate data task."""
await self._async_handle_chat_log(chat_log) await self._async_handle_chat_log(chat_log, task.structure)
if not isinstance(chat_log.content[-1], conversation.AssistantContent): if not isinstance(chat_log.content[-1], conversation.AssistantContent):
LOGGER.error( LOGGER.error(
@@ -51,7 +54,25 @@ class GoogleGenerativeAITaskEntity(
) )
raise HomeAssistantError(ERROR_GETTING_RESPONSE) raise HomeAssistantError(ERROR_GETTING_RESPONSE)
text = chat_log.content[-1].content or ""
if not task.structure:
return ai_task.GenDataTaskResult( return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id, conversation_id=chat_log.conversation_id,
data=chat_log.content[-1].content or "", data=text,
)
try:
data = json_loads(text)
except JSONDecodeError as err:
LOGGER.error(
"Failed to parse JSON response: %s. Response: %s",
err,
text,
)
raise HomeAssistantError(ERROR_GETTING_RESPONSE) from err
return ai_task.GenDataTaskResult(
conversation_id=chat_log.conversation_id,
data=data,
) )

View File

@@ -21,6 +21,7 @@ from google.genai.types import (
Schema, Schema,
Tool, Tool,
) )
import voluptuous as vol
from voluptuous_openapi import convert from voluptuous_openapi import convert
from homeassistant.components import conversation from homeassistant.components import conversation
@@ -324,6 +325,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
async def _async_handle_chat_log( async def _async_handle_chat_log(
self, self,
chat_log: conversation.ChatLog, chat_log: conversation.ChatLog,
structure: vol.Schema | None = None,
) -> None: ) -> None:
"""Generate an answer for the chat log.""" """Generate an answer for the chat log."""
options = self.subentry.data options = self.subentry.data
@@ -402,6 +404,18 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
generateContentConfig.automatic_function_calling = ( generateContentConfig.automatic_function_calling = (
AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None) AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None)
) )
if structure:
generateContentConfig.response_mime_type = "application/json"
generateContentConfig.response_schema = _format_schema(
convert(
structure,
custom_serializer=(
chat_log.llm_api.custom_serializer
if chat_log.llm_api
else llm.selector_serializer
),
)
)
if not supports_system_instruction: if not supports_system_instruction:
messages = [ messages = [

View File

@@ -458,7 +458,7 @@ class AssistAPI(API):
api_prompt=self._async_get_api_prompt(llm_context, exposed_entities), api_prompt=self._async_get_api_prompt(llm_context, exposed_entities),
llm_context=llm_context, llm_context=llm_context,
tools=self._async_get_tools(llm_context, exposed_entities), tools=self._async_get_tools(llm_context, exposed_entities),
custom_serializer=_selector_serializer, custom_serializer=selector_serializer,
) )
@callback @callback
@@ -701,7 +701,7 @@ def _get_exposed_entities(
return data return data
def _selector_serializer(schema: Any) -> Any: # noqa: C901 def selector_serializer(schema: Any) -> Any: # noqa: C901
"""Convert selectors into OpenAPI schema.""" """Convert selectors into OpenAPI schema."""
if not isinstance(schema, selector.Selector): if not isinstance(schema, selector.Selector):
return UNSUPPORTED return UNSUPPORTED
@@ -782,7 +782,7 @@ def _selector_serializer(schema: Any) -> Any:  # noqa: C901
result["properties"] = { result["properties"] = {
field: convert( field: convert(
selector.selector(field_schema["selector"]), selector.selector(field_schema["selector"]),
custom_serializer=_selector_serializer, custom_serializer=selector_serializer,
) )
for field, field_schema in fields.items() for field, field_schema in fields.items()
} }

View File

@@ -112,19 +112,26 @@ async def setup_ha(hass: HomeAssistant) -> None:
@pytest.fixture @pytest.fixture
def mock_send_message_stream() -> Generator[AsyncMock]: def mock_chat_create() -> Generator[AsyncMock]:
"""Mock stream response.""" """Mock stream response."""
async def mock_generator(stream): async def mock_generator(stream):
for value in stream: for value in stream:
yield value yield value
with patch( mock_send_message_stream = AsyncMock()
"google.genai.chats.AsyncChat.send_message_stream",
AsyncMock(),
) as mock_send_message_stream:
mock_send_message_stream.side_effect = lambda **kwargs: mock_generator( mock_send_message_stream.side_effect = lambda **kwargs: mock_generator(
mock_send_message_stream.return_value.pop(0) mock_send_message_stream.return_value.pop(0)
) )
yield mock_send_message_stream with patch(
"google.genai.chats.AsyncChats.create",
return_value=AsyncMock(send_message_stream=mock_send_message_stream),
) as mock_create:
yield mock_create
@pytest.fixture
def mock_send_message_stream(mock_chat_create) -> Generator[AsyncMock]:
"""Mock stream response."""
return mock_chat_create.return_value.send_message_stream

View File

@@ -4,10 +4,12 @@ from unittest.mock import AsyncMock
from google.genai.types import GenerateContentResponse from google.genai.types import GenerateContentResponse
import pytest import pytest
import voluptuous as vol
from homeassistant.components import ai_task from homeassistant.components import ai_task
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er, selector
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
from tests.components.conversation import ( from tests.components.conversation import (
@@ -17,14 +19,15 @@ from tests.components.conversation import (
@pytest.mark.usefixtures("mock_init_component") @pytest.mark.usefixtures("mock_init_component")
async def test_run_task( async def test_generate_data(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
mock_chat_log: MockChatLog, # noqa: F811 mock_chat_log: MockChatLog, # noqa: F811
mock_send_message_stream: AsyncMock, mock_send_message_stream: AsyncMock,
mock_chat_create: AsyncMock,
entity_registry: er.EntityRegistry, entity_registry: er.EntityRegistry,
) -> None: ) -> None:
"""Test empty response.""" """Test generating data."""
entity_id = "ai_task.google_ai_task" entity_id = "ai_task.google_ai_task"
# Ensure it's linked to the subentry # Ensure it's linked to the subentry
@@ -60,3 +63,68 @@ async def test_run_task(
instructions="Test prompt", instructions="Test prompt",
) )
assert result.data == "Hi there!" assert result.data == "Hi there!"
mock_send_message_stream.return_value = [
[
GenerateContentResponse(
candidates=[
{
"content": {
"parts": [{"text": '{"characters": ["Mario", "Luigi"]}'}],
"role": "model",
},
}
],
),
],
]
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Give me 2 mario characters",
structure=vol.Schema(
{
vol.Required("characters"): selector.selector(
{
"text": {
"multiple": True,
}
}
)
},
),
)
assert result.data == {"characters": ["Mario", "Luigi"]}
assert len(mock_chat_create.mock_calls) == 2
config = mock_chat_create.mock_calls[-1][2]["config"]
assert config.response_mime_type == "application/json"
assert config.response_schema == {
"properties": {"characters": {"items": {"type": "STRING"}, "type": "ARRAY"}},
"required": ["characters"],
"type": "OBJECT",
}
# Raise error on invalid JSON response
mock_send_message_stream.return_value = [
[
GenerateContentResponse(
candidates=[
{
"content": {
"parts": [{"text": "INVALID JSON RESPONSE"}],
"role": "model",
},
}
],
),
],
]
with pytest.raises(HomeAssistantError):
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Test prompt",
structure=vol.Schema({vol.Required("bla"): str}),
)