Add Google Search tool in Google Generative AI (#140772)

* Added Google Search grounding

* Added testing
This commit is contained in:
Ivan Lopez Hernandez 2025-03-23 22:23:52 -07:00 committed by GitHub
parent af96fedc0f
commit 6a7fa3769d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 142 additions and 1 deletions

View File

@ -44,6 +44,7 @@ from .const import (
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
@ -51,6 +52,7 @@ from .const import (
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
TIMEOUT_MILLIS,
)
@ -341,6 +343,13 @@ async def google_generative_ai_config_option_schema(
},
default=RECOMMENDED_HARM_BLOCK_THRESHOLD,
): harm_block_thresholds_selector,
vol.Optional(
CONF_USE_GOOGLE_SEARCH_TOOL,
description={
"suggested_value": options.get(CONF_USE_GOOGLE_SEARCH_TOOL),
},
default=RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
): bool,
}
)
return schema

View File

@ -22,5 +22,7 @@ CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold"
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"
CONF_USE_GOOGLE_SEARCH_TOOL = "enable_google_search_tool"
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False
TIMEOUT_MILLIS = 10000

View File

@ -12,6 +12,7 @@ from google.genai.types import (
Content,
FunctionDeclaration,
GenerateContentConfig,
GoogleSearch,
HarmCategory,
Part,
SafetySetting,
@ -39,6 +40,7 @@ from .const import (
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL,
DOMAIN,
LOGGER,
RECOMMENDED_CHAT_MODEL,
@ -296,6 +298,13 @@ class GoogleGenerativeAIConversationEntity(
for tool in chat_log.llm_api.tools
]
# Using search grounding allows the model to retrieve information from the web,
# however, it may interfere with how the model decides to use some tools, or entities
# for example weather entity may be disregarded if the model chooses to Google it.
if options.get(CONF_USE_GOOGLE_SEARCH_TOOL) is True:
tools = tools or []
tools.append(Tool(google_search=GoogleSearch()))
model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
# Gemini 1.0 doesn't support system_instruction while 1.5 does.
# Assume future versions will support it (if not, the request fails with a

View File

@ -36,7 +36,8 @@
"harassment_block_threshold": "Negative or harmful comments targeting identity and/or protected attributes",
"hate_block_threshold": "Content that is rude, disrespectful, or profane",
"sexual_block_threshold": "Contains references to sexual acts or other lewd content",
"dangerous_block_threshold": "Promotes, facilitates, or encourages harmful acts"
"dangerous_block_threshold": "Promotes, facilitates, or encourages harmful acts",
"enable_google_search_tool": "Enable Google Search tool"
},
"data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template."

View File

@ -4,6 +4,9 @@ from unittest.mock import Mock, patch
import pytest
from homeassistant.components.google_generative_ai_conversation.conversation import (
CONF_USE_GOOGLE_SEARCH_TOOL,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
@ -41,6 +44,23 @@ async def mock_config_entry_with_assist(
return mock_config_entry
@pytest.fixture
async def mock_config_entry_with_google_search(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> MockConfigEntry:
"""Mock a config entry with assist."""
with patch("google.genai.models.AsyncModels.get"):
hass.config_entries.async_update_entry(
mock_config_entry,
options={
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
CONF_USE_GOOGLE_SEARCH_TOOL: True,
},
)
await hass.async_block_till_done()
return mock_config_entry
@pytest.fixture
async def mock_init_component(
hass: HomeAssistant, mock_config_entry: ConfigEntry

View File

@ -61,3 +61,34 @@
),
])
# ---
# name: test_use_google_search
list([
tuple(
'',
tuple(
),
dict({
'config': GenerateContentConfig(http_options=None, system_instruction="Current time is 05:00:00. Today's date is 2024-05-24.\nYou are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=150, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.OBJECT: 'OBJECT'>, description=None, enum=None, format=None, items=None, properties={'param1': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.ARRAY: 'ARRAY'>, description='Test parameters', enum=None, format=None, items=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.STRING: 'STRING'>, description=None, enum=None, format=None, items=None, properties=None, required=None), properties=None, required=None), 'param2': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=None, description=None, enum=None, format=None, items=None, properties=None, required=None), 'param3': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.OBJECT: 'OBJECT'>, description=None, enum=None, format=None, items=None, properties={'json': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.STRING: 'STRING'>, description=None, enum=None, format=None, items=None, properties=None, required=None)}, required=[])}, required=[]))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None), Tool(function_declarations=None, retrieval=None, google_search=GoogleSearch(), google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
'history': list([
]),
'model': 'models/gemini-2.0-flash',
}),
),
tuple(
'().send_message',
tuple(
),
dict({
'message': 'Please call the test function',
}),
),
tuple(
'().send_message',
tuple(
),
dict({
'message': Content(parts=[Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None)], role=None),
}),
),
])
# ---

View File

@ -21,12 +21,14 @@ from homeassistant.components.google_generative_ai_conversation.const import (
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
@ -143,6 +145,7 @@ async def test_form(hass: HomeAssistant) -> None:
CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_USE_GOOGLE_SEARCH_TOOL: RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
},
),
(

View File

@ -176,6 +176,72 @@ async def test_function_call(
}
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
)
@pytest.mark.usefixtures("mock_init_component")
@pytest.mark.usefixtures("mock_ulid_tools")
async def test_use_google_search(
mock_get_tools,
hass: HomeAssistant,
mock_config_entry_with_google_search: MockConfigEntry,
snapshot: SnapshotAssertion,
) -> None:
"""Test function calling."""
agent_id = "conversation.google_generative_ai_conversation"
context = Context()
mock_tool = AsyncMock()
mock_tool.name = "test_tool"
mock_tool.description = "Test function"
mock_tool.parameters = vol.Schema(
{
vol.Optional("param1", description="Test parameters"): [
vol.All(str, vol.Lower)
],
vol.Optional("param2"): vol.Any(float, int),
vol.Optional("param3"): dict,
}
)
mock_get_tools.return_value = [mock_tool]
with patch("google.genai.chats.AsyncChats.create") as mock_create:
mock_chat = AsyncMock()
mock_create.return_value.send_message = mock_chat
chat_response = Mock(prompt_feedback=None)
mock_chat.return_value = chat_response
mock_part = Mock()
mock_part.text = ""
mock_part.function_call = FunctionCall(
name="test_tool",
args={
"param1": ["test_value", "param1\\'s value"],
"param2": 2.7,
},
)
def tool_call(
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
) -> dict[str, Any]:
mock_part.function_call = None
mock_part.text = "Hi there!"
return {"result": "Test response"}
mock_tool.async_call.side_effect = tool_call
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
await conversation.async_converse(
hass,
"Please call the test function",
None,
context,
agent_id=agent_id,
device_id="test_device",
)
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
)