Mirror of https://github.com/home-assistant/core.git (synced 2025-07-28 15:47:12 +00:00)
Add Google Gen AI structured data support (#148143)
commit 4f4ec6f41a
parent 33d05d99eb
@@ -2,11 +2,14 @@
 from __future__ import annotations
 
+from json import JSONDecodeError
+
 from homeassistant.components import ai_task, conversation
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.core import HomeAssistant
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
+from homeassistant.util.json import json_loads
 
 from .const import LOGGER
 from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity
@@ -42,7 +45,7 @@ class GoogleGenerativeAITaskEntity(
         chat_log: conversation.ChatLog,
     ) -> ai_task.GenDataTaskResult:
         """Handle a generate data task."""
-        await self._async_handle_chat_log(chat_log)
+        await self._async_handle_chat_log(chat_log, task.structure)
 
         if not isinstance(chat_log.content[-1], conversation.AssistantContent):
             LOGGER.error(
@@ -51,7 +54,25 @@ class GoogleGenerativeAITaskEntity(
             )
             raise HomeAssistantError(ERROR_GETTING_RESPONSE)
 
-        return ai_task.GenDataTaskResult(
-            conversation_id=chat_log.conversation_id,
-            data=chat_log.content[-1].content or "",
-        )
+        text = chat_log.content[-1].content or ""
+
+        if not task.structure:
+            return ai_task.GenDataTaskResult(
+                conversation_id=chat_log.conversation_id,
+                data=text,
+            )
+        try:
+            data = json_loads(text)
+        except JSONDecodeError as err:
+            LOGGER.error(
+                "Failed to parse JSON response: %s. Response: %s",
+                err,
+                text,
+            )
+            raise HomeAssistantError(ERROR_GETTING_RESPONSE) from err
+
+        return ai_task.GenDataTaskResult(
+            conversation_id=chat_log.conversation_id,
+            data=data,
+        )
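The ai_task change above forwards the task's requested structure into the chat handler and, when a structure was requested, parses the model's reply as JSON, raising the usual ERROR_GETTING_RESPONSE error if parsing fails. A minimal standalone sketch of that parse-or-fail pattern (plain json.loads stands in for homeassistant.util.json.json_loads, and ValueError for HomeAssistantError; the helper name is illustrative):

# Minimal sketch of the parse-or-fail flow added in ai_task above.
# json.loads stands in for homeassistant.util.json.json_loads and ValueError
# for HomeAssistantError; names here are illustrative only.
from json import JSONDecodeError, loads


def result_from_model_text(text: str, structured: bool) -> object:
    """Return the raw text, or parsed JSON when a structure was requested."""
    if not structured:
        return text
    try:
        return loads(text)
    except JSONDecodeError as err:
        # The integration logs the raw response and raises HomeAssistantError instead.
        raise ValueError(f"Model did not return valid JSON: {text!r}") from err


print(result_from_model_text("Hi there!", structured=False))
print(result_from_model_text('{"characters": ["Mario", "Luigi"]}', structured=True))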
@@ -21,6 +21,7 @@ from google.genai.types import (
     Schema,
     Tool,
 )
+import voluptuous as vol
 from voluptuous_openapi import convert
 
 from homeassistant.components import conversation
@@ -324,6 +325,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
     async def _async_handle_chat_log(
         self,
         chat_log: conversation.ChatLog,
+        structure: vol.Schema | None = None,
     ) -> None:
         """Generate an answer for the chat log."""
         options = self.subentry.data
@@ -402,6 +404,18 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
         generateContentConfig.automatic_function_calling = (
             AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None)
         )
+        if structure:
+            generateContentConfig.response_mime_type = "application/json"
+            generateContentConfig.response_schema = _format_schema(
+                convert(
+                    structure,
+                    custom_serializer=(
+                        chat_log.llm_api.custom_serializer
+                        if chat_log.llm_api
+                        else llm.selector_serializer
+                    ),
+                )
+            )
 
         if not supports_system_instruction:
             messages = [
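With a structure present, the entity asks Gemini for JSON output by setting response_mime_type and a response_schema built from the voluptuous schema. A hedged sketch of those same two config fields used directly against the google-genai SDK; the client setup and model id below are illustrative assumptions, not the integration's code:

# Hedged sketch of the Gemini config the entity builds when a structure is given.
# Client construction and the model id are illustrative assumptions; the two
# response_* fields mirror the diff above.
from google import genai
from google.genai.types import GenerateContentConfig

client = genai.Client(api_key="YOUR_API_KEY")  # placeholder credentials

config = GenerateContentConfig(
    response_mime_type="application/json",
    response_schema={
        "type": "OBJECT",
        "properties": {"characters": {"type": "ARRAY", "items": {"type": "STRING"}}},
        "required": ["characters"],
    },
)

response = client.models.generate_content(
    model="gemini-2.5-flash",  # illustrative model id
    contents="Give me 2 mario characters",
    config=config,
)
print(response.text)  # expected: something like '{"characters": ["Mario", "Luigi"]}'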
@@ -458,7 +458,7 @@ class AssistAPI(API):
             api_prompt=self._async_get_api_prompt(llm_context, exposed_entities),
             llm_context=llm_context,
             tools=self._async_get_tools(llm_context, exposed_entities),
-            custom_serializer=_selector_serializer,
+            custom_serializer=selector_serializer,
         )
 
     @callback
@@ -701,7 +701,7 @@ def _get_exposed_entities(
     return data
 
 
-def _selector_serializer(schema: Any) -> Any:  # noqa: C901
+def selector_serializer(schema: Any) -> Any:  # noqa: C901
     """Convert selectors into OpenAPI schema."""
     if not isinstance(schema, selector.Selector):
         return UNSUPPORTED
@@ -782,7 +782,7 @@ def _selector_serializer(schema: Any) -> Any:  # noqa: C901
         result["properties"] = {
             field: convert(
                 selector.selector(field_schema["selector"]),
-                custom_serializer=_selector_serializer,
+                custom_serializer=selector_serializer,
             )
             for field, field_schema in fields.items()
         }
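The llm helper's selector serializer is made public (selector_serializer, previously _selector_serializer) so other code, including the entity change above, can pass it to convert() when no LLM API is attached to the chat log. A short usage sketch, assuming a Home Assistant environment where these helpers are importable:

# Sketch: using the now-public serializer to turn a selector-based voluptuous
# schema into an OpenAPI/JSON-style schema. Assumes a Home Assistant environment.
import voluptuous as vol
from voluptuous_openapi import convert

from homeassistant.helpers import llm, selector

structure = vol.Schema(
    {vol.Required("characters"): selector.selector({"text": {"multiple": True}})}
)

# Selectors are not plain Python types, so convert() needs the custom serializer
# to map them onto schema primitives (here, an array of strings).
json_schema = convert(structure, custom_serializer=llm.selector_serializer)
print(json_schema)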
@@ -112,19 +112,26 @@ async def setup_ha(hass: HomeAssistant) -> None:
 
 
 @pytest.fixture
-def mock_send_message_stream() -> Generator[AsyncMock]:
+def mock_chat_create() -> Generator[AsyncMock]:
     """Mock stream response."""
 
     async def mock_generator(stream):
         for value in stream:
             yield value
 
-    with patch(
-        "google.genai.chats.AsyncChat.send_message_stream",
-        AsyncMock(),
-    ) as mock_send_message_stream:
-        mock_send_message_stream.side_effect = lambda **kwargs: mock_generator(
-            mock_send_message_stream.return_value.pop(0)
-        )
-
-        yield mock_send_message_stream
+    mock_send_message_stream = AsyncMock()
+    mock_send_message_stream.side_effect = lambda **kwargs: mock_generator(
+        mock_send_message_stream.return_value.pop(0)
+    )
+
+    with patch(
+        "google.genai.chats.AsyncChats.create",
+        return_value=AsyncMock(send_message_stream=mock_send_message_stream),
+    ) as mock_create:
+        yield mock_create
+
+
+@pytest.fixture
+def mock_send_message_stream(mock_chat_create) -> Generator[AsyncMock]:
+    """Mock stream response."""
+    return mock_chat_create.return_value.send_message_stream
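The fixtures now patch google.genai.chats.AsyncChats.create so tests can inspect the config a chat was created with (mock_chat_create) while mock_send_message_stream keeps scripting streamed replies; its side effect pops the next pre-queued list of chunks on every call. A standalone sketch of that queuing trick using only unittest.mock (names and chunk values here are illustrative):

# Standalone sketch of the fixture's queuing trick: each call to the mocked
# stream pops the next queued list of chunks and exposes it as an async stream.
import asyncio
from unittest.mock import AsyncMock


async def _stream(chunks):
    for chunk in chunks:
        yield chunk


mock_stream = AsyncMock()
mock_stream.side_effect = lambda **kwargs: _stream(mock_stream.return_value.pop(0))
mock_stream.return_value = [["Hi", " there!"], ["Second reply"]]


async def main() -> None:
    first = await mock_stream()
    print([chunk async for chunk in first])   # ['Hi', ' there!']
    second = await mock_stream()
    print([chunk async for chunk in second])  # ['Second reply']


asyncio.run(main())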
@@ -4,10 +4,12 @@ from unittest.mock import AsyncMock
 
 from google.genai.types import GenerateContentResponse
 import pytest
+import voluptuous as vol
 
 from homeassistant.components import ai_task
 from homeassistant.core import HomeAssistant
-from homeassistant.helpers import entity_registry as er
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import entity_registry as er, selector
 
 from tests.common import MockConfigEntry
 from tests.components.conversation import (
@@ -17,14 +19,15 @@ from tests.components.conversation import (
 
 
 @pytest.mark.usefixtures("mock_init_component")
-async def test_run_task(
+async def test_generate_data(
     hass: HomeAssistant,
     mock_config_entry: MockConfigEntry,
     mock_chat_log: MockChatLog,  # noqa: F811
     mock_send_message_stream: AsyncMock,
+    mock_chat_create: AsyncMock,
     entity_registry: er.EntityRegistry,
 ) -> None:
-    """Test empty response."""
+    """Test generating data."""
     entity_id = "ai_task.google_ai_task"
 
     # Ensure it's linked to the subentry
@@ -60,3 +63,68 @@
         instructions="Test prompt",
     )
     assert result.data == "Hi there!"
+
+    mock_send_message_stream.return_value = [
+        [
+            GenerateContentResponse(
+                candidates=[
+                    {
+                        "content": {
+                            "parts": [{"text": '{"characters": ["Mario", "Luigi"]}'}],
+                            "role": "model",
+                        },
+                    }
+                ],
+            ),
+        ],
+    ]
+    result = await ai_task.async_generate_data(
+        hass,
+        task_name="Test Task",
+        entity_id=entity_id,
+        instructions="Give me 2 mario characters",
+        structure=vol.Schema(
+            {
+                vol.Required("characters"): selector.selector(
+                    {
+                        "text": {
+                            "multiple": True,
+                        }
+                    }
+                )
+            },
+        ),
+    )
+    assert result.data == {"characters": ["Mario", "Luigi"]}
+
+    assert len(mock_chat_create.mock_calls) == 2
+    config = mock_chat_create.mock_calls[-1][2]["config"]
+    assert config.response_mime_type == "application/json"
+    assert config.response_schema == {
+        "properties": {"characters": {"items": {"type": "STRING"}, "type": "ARRAY"}},
+        "required": ["characters"],
+        "type": "OBJECT",
+    }
+    # Raise error on invalid JSON response
+    mock_send_message_stream.return_value = [
+        [
+            GenerateContentResponse(
+                candidates=[
+                    {
+                        "content": {
+                            "parts": [{"text": "INVALID JSON RESPONSE"}],
+                            "role": "model",
+                        },
+                    }
+                ],
+            ),
+        ],
+    ]
+    with pytest.raises(HomeAssistantError):
+        result = await ai_task.async_generate_data(
+            hass,
+            task_name="Test Task",
+            entity_id=entity_id,
+            instructions="Test prompt",
+            structure=vol.Schema({vol.Required("bla"): str}),
+        )
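For context on the mocked stream chunks in this test: each queued item is a google-genai GenerateContentResponse whose candidate parts carry the text that the integration concatenates and, when a structure was requested, JSON-parses. A rough sketch of one such chunk built outside a test; the .text convenience accessor is the SDK's, and the printed value is the expected, not verified, output:

# Rough sketch of one mocked stream chunk from the test above, constructed the
# same way the test does; the printed value is an expectation, not a guarantee.
from google.genai.types import GenerateContentResponse

chunk = GenerateContentResponse(
    candidates=[
        {
            "content": {
                "parts": [{"text": '{"characters": ["Mario", "Luigi"]}'}],
                "role": "model",
            },
        }
    ],
)
print(chunk.text)  # expected: '{"characters": ["Mario", "Luigi"]}'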