mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Add Google Search tool in Google Generative AI (#140772)
* Added Google Search grounding * Added testing
This commit is contained in:
parent
af96fedc0f
commit
6a7fa3769d
@ -44,6 +44,7 @@ from .const import (
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DOMAIN,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
@ -51,6 +52,7 @@ from .const import (
|
||||
RECOMMENDED_TEMPERATURE,
|
||||
RECOMMENDED_TOP_K,
|
||||
RECOMMENDED_TOP_P,
|
||||
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||
TIMEOUT_MILLIS,
|
||||
)
|
||||
|
||||
@ -341,6 +343,13 @@ async def google_generative_ai_config_option_schema(
|
||||
},
|
||||
default=RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
): harm_block_thresholds_selector,
|
||||
vol.Optional(
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
description={
|
||||
"suggested_value": options.get(CONF_USE_GOOGLE_SEARCH_TOOL),
|
||||
},
|
||||
default=RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||
): bool,
|
||||
}
|
||||
)
|
||||
return schema
|
||||
|
@ -22,5 +22,7 @@ CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
|
||||
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
|
||||
CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold"
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL = "enable_google_search_tool"
|
||||
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False
|
||||
|
||||
TIMEOUT_MILLIS = 10000
|
||||
|
@ -12,6 +12,7 @@ from google.genai.types import (
|
||||
Content,
|
||||
FunctionDeclaration,
|
||||
GenerateContentConfig,
|
||||
GoogleSearch,
|
||||
HarmCategory,
|
||||
Part,
|
||||
SafetySetting,
|
||||
@ -39,6 +40,7 @@ from .const import (
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DOMAIN,
|
||||
LOGGER,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
@ -296,6 +298,13 @@ class GoogleGenerativeAIConversationEntity(
|
||||
for tool in chat_log.llm_api.tools
|
||||
]
|
||||
|
||||
# Using search grounding allows the model to retrieve information from the web,
|
||||
# however, it may interfere with how the model decides to use some tools, or entities
|
||||
# for example weather entity may be disregarded if the model chooses to Google it.
|
||||
if options.get(CONF_USE_GOOGLE_SEARCH_TOOL) is True:
|
||||
tools = tools or []
|
||||
tools.append(Tool(google_search=GoogleSearch()))
|
||||
|
||||
model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||
# Gemini 1.0 doesn't support system_instruction while 1.5 does.
|
||||
# Assume future versions will support it (if not, the request fails with a
|
||||
|
@ -36,7 +36,8 @@
|
||||
"harassment_block_threshold": "Negative or harmful comments targeting identity and/or protected attributes",
|
||||
"hate_block_threshold": "Content that is rude, disrespectful, or profane",
|
||||
"sexual_block_threshold": "Contains references to sexual acts or other lewd content",
|
||||
"dangerous_block_threshold": "Promotes, facilitates, or encourages harmful acts"
|
||||
"dangerous_block_threshold": "Promotes, facilitates, or encourages harmful acts",
|
||||
"enable_google_search_tool": "Enable Google Search tool"
|
||||
},
|
||||
"data_description": {
|
||||
"prompt": "Instruct how the LLM should respond. This can be a template."
|
||||
|
@ -4,6 +4,9 @@ from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.google_generative_ai_conversation.conversation import (
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
@ -41,6 +44,23 @@ async def mock_config_entry_with_assist(
|
||||
return mock_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_config_entry_with_google_search(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> MockConfigEntry:
|
||||
"""Mock a config entry with assist."""
|
||||
with patch("google.genai.models.AsyncModels.get"):
|
||||
hass.config_entries.async_update_entry(
|
||||
mock_config_entry,
|
||||
options={
|
||||
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL: True,
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
return mock_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_init_component(
|
||||
hass: HomeAssistant, mock_config_entry: ConfigEntry
|
||||
|
@ -61,3 +61,34 @@
|
||||
),
|
||||
])
|
||||
# ---
|
||||
# name: test_use_google_search
|
||||
list([
|
||||
tuple(
|
||||
'',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'config': GenerateContentConfig(http_options=None, system_instruction="Current time is 05:00:00. Today's date is 2024-05-24.\nYou are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=150, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.OBJECT: 'OBJECT'>, description=None, enum=None, format=None, items=None, properties={'param1': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.ARRAY: 'ARRAY'>, description='Test parameters', enum=None, format=None, items=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.STRING: 'STRING'>, description=None, enum=None, format=None, items=None, properties=None, required=None), properties=None, required=None), 'param2': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=None, description=None, enum=None, format=None, items=None, properties=None, required=None), 'param3': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.OBJECT: 'OBJECT'>, description=None, enum=None, format=None, items=None, properties={'json': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.STRING: 'STRING'>, description=None, enum=None, format=None, items=None, properties=None, required=None)}, required=[])}, required=[]))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None), Tool(function_declarations=None, retrieval=None, google_search=GoogleSearch(), google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
|
||||
'history': list([
|
||||
]),
|
||||
'model': 'models/gemini-2.0-flash',
|
||||
}),
|
||||
),
|
||||
tuple(
|
||||
'().send_message',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'message': 'Please call the test function',
|
||||
}),
|
||||
),
|
||||
tuple(
|
||||
'().send_message',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'message': Content(parts=[Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None)], role=None),
|
||||
}),
|
||||
),
|
||||
])
|
||||
# ---
|
||||
|
@ -21,12 +21,14 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DOMAIN,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_TOP_K,
|
||||
RECOMMENDED_TOP_P,
|
||||
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||
)
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
@ -143,6 +145,7 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL: RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||
},
|
||||
),
|
||||
(
|
||||
|
@ -176,6 +176,72 @@ async def test_function_call(
|
||||
}
|
||||
|
||||
|
||||
@patch(
|
||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||
)
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
@pytest.mark.usefixtures("mock_ulid_tools")
|
||||
async def test_use_google_search(
|
||||
mock_get_tools,
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_google_search: MockConfigEntry,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test function calling."""
|
||||
agent_id = "conversation.google_generative_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
mock_tool = AsyncMock()
|
||||
mock_tool.name = "test_tool"
|
||||
mock_tool.description = "Test function"
|
||||
mock_tool.parameters = vol.Schema(
|
||||
{
|
||||
vol.Optional("param1", description="Test parameters"): [
|
||||
vol.All(str, vol.Lower)
|
||||
],
|
||||
vol.Optional("param2"): vol.Any(float, int),
|
||||
vol.Optional("param3"): dict,
|
||||
}
|
||||
)
|
||||
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
|
||||
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
||||
mock_chat = AsyncMock()
|
||||
mock_create.return_value.send_message = mock_chat
|
||||
chat_response = Mock(prompt_feedback=None)
|
||||
mock_chat.return_value = chat_response
|
||||
mock_part = Mock()
|
||||
mock_part.text = ""
|
||||
mock_part.function_call = FunctionCall(
|
||||
name="test_tool",
|
||||
args={
|
||||
"param1": ["test_value", "param1\\'s value"],
|
||||
"param2": 2.7,
|
||||
},
|
||||
)
|
||||
|
||||
def tool_call(
|
||||
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
|
||||
) -> dict[str, Any]:
|
||||
mock_part.function_call = None
|
||||
mock_part.text = "Hi there!"
|
||||
return {"result": "Test response"}
|
||||
|
||||
mock_tool.async_call.side_effect = tool_call
|
||||
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
|
||||
await conversation.async_converse(
|
||||
hass,
|
||||
"Please call the test function",
|
||||
None,
|
||||
context,
|
||||
agent_id=agent_id,
|
||||
device_id="test_device",
|
||||
)
|
||||
|
||||
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
|
||||
|
||||
|
||||
@patch(
|
||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user