mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00
Add Google Search tool in Google Generative AI (#140772)
* Added Google Search grounding * Added testing
This commit is contained in:
parent
af96fedc0f
commit
6a7fa3769d
@ -44,6 +44,7 @@ from .const import (
|
|||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
CONF_TOP_K,
|
CONF_TOP_K,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
|
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||||
@ -51,6 +52,7 @@ from .const import (
|
|||||||
RECOMMENDED_TEMPERATURE,
|
RECOMMENDED_TEMPERATURE,
|
||||||
RECOMMENDED_TOP_K,
|
RECOMMENDED_TOP_K,
|
||||||
RECOMMENDED_TOP_P,
|
RECOMMENDED_TOP_P,
|
||||||
|
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||||
TIMEOUT_MILLIS,
|
TIMEOUT_MILLIS,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -341,6 +343,13 @@ async def google_generative_ai_config_option_schema(
|
|||||||
},
|
},
|
||||||
default=RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
default=RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||||
): harm_block_thresholds_selector,
|
): harm_block_thresholds_selector,
|
||||||
|
vol.Optional(
|
||||||
|
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||||
|
description={
|
||||||
|
"suggested_value": options.get(CONF_USE_GOOGLE_SEARCH_TOOL),
|
||||||
|
},
|
||||||
|
default=RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||||
|
): bool,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return schema
|
return schema
|
||||||
|
@ -22,5 +22,7 @@ CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
|
|||||||
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
|
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
|
||||||
CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold"
|
CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold"
|
||||||
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"
|
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"
|
||||||
|
CONF_USE_GOOGLE_SEARCH_TOOL = "enable_google_search_tool"
|
||||||
|
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False
|
||||||
|
|
||||||
TIMEOUT_MILLIS = 10000
|
TIMEOUT_MILLIS = 10000
|
||||||
|
@ -12,6 +12,7 @@ from google.genai.types import (
|
|||||||
Content,
|
Content,
|
||||||
FunctionDeclaration,
|
FunctionDeclaration,
|
||||||
GenerateContentConfig,
|
GenerateContentConfig,
|
||||||
|
GoogleSearch,
|
||||||
HarmCategory,
|
HarmCategory,
|
||||||
Part,
|
Part,
|
||||||
SafetySetting,
|
SafetySetting,
|
||||||
@ -39,6 +40,7 @@ from .const import (
|
|||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
CONF_TOP_K,
|
CONF_TOP_K,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
|
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
LOGGER,
|
LOGGER,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
@ -296,6 +298,13 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
for tool in chat_log.llm_api.tools
|
for tool in chat_log.llm_api.tools
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Using search grounding allows the model to retrieve information from the web,
|
||||||
|
# however, it may interfere with how the model decides to use some tools, or entities
|
||||||
|
# for example weather entity may be disregarded if the model chooses to Google it.
|
||||||
|
if options.get(CONF_USE_GOOGLE_SEARCH_TOOL) is True:
|
||||||
|
tools = tools or []
|
||||||
|
tools.append(Tool(google_search=GoogleSearch()))
|
||||||
|
|
||||||
model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
model_name = self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||||
# Gemini 1.0 doesn't support system_instruction while 1.5 does.
|
# Gemini 1.0 doesn't support system_instruction while 1.5 does.
|
||||||
# Assume future versions will support it (if not, the request fails with a
|
# Assume future versions will support it (if not, the request fails with a
|
||||||
|
@ -36,7 +36,8 @@
|
|||||||
"harassment_block_threshold": "Negative or harmful comments targeting identity and/or protected attributes",
|
"harassment_block_threshold": "Negative or harmful comments targeting identity and/or protected attributes",
|
||||||
"hate_block_threshold": "Content that is rude, disrespectful, or profane",
|
"hate_block_threshold": "Content that is rude, disrespectful, or profane",
|
||||||
"sexual_block_threshold": "Contains references to sexual acts or other lewd content",
|
"sexual_block_threshold": "Contains references to sexual acts or other lewd content",
|
||||||
"dangerous_block_threshold": "Promotes, facilitates, or encourages harmful acts"
|
"dangerous_block_threshold": "Promotes, facilitates, or encourages harmful acts",
|
||||||
|
"enable_google_search_tool": "Enable Google Search tool"
|
||||||
},
|
},
|
||||||
"data_description": {
|
"data_description": {
|
||||||
"prompt": "Instruct how the LLM should respond. This can be a template."
|
"prompt": "Instruct how the LLM should respond. This can be a template."
|
||||||
|
@ -4,6 +4,9 @@ from unittest.mock import Mock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.google_generative_ai_conversation.conversation import (
|
||||||
|
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||||
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -41,6 +44,23 @@ async def mock_config_entry_with_assist(
|
|||||||
return mock_config_entry
|
return mock_config_entry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_config_entry_with_google_search(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||||
|
) -> MockConfigEntry:
|
||||||
|
"""Mock a config entry with assist."""
|
||||||
|
with patch("google.genai.models.AsyncModels.get"):
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
mock_config_entry,
|
||||||
|
options={
|
||||||
|
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||||
|
CONF_USE_GOOGLE_SEARCH_TOOL: True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
return mock_config_entry
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mock_init_component(
|
async def mock_init_component(
|
||||||
hass: HomeAssistant, mock_config_entry: ConfigEntry
|
hass: HomeAssistant, mock_config_entry: ConfigEntry
|
||||||
|
@ -61,3 +61,34 @@
|
|||||||
),
|
),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_use_google_search
|
||||||
|
list([
|
||||||
|
tuple(
|
||||||
|
'',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'config': GenerateContentConfig(http_options=None, system_instruction="Current time is 05:00:00. Today's date is 2024-05-24.\nYou are a voice assistant for Home Assistant.\nAnswer questions about the world truthfully.\nAnswer in plain text. Keep it simple and to the point.\nOnly if the user wants to control a device, tell them to expose entities to their voice assistant in Home Assistant.", temperature=1.0, top_p=0.95, top_k=64.0, candidate_count=None, max_output_tokens=150, stop_sequences=None, response_logprobs=None, logprobs=None, presence_penalty=None, frequency_penalty=None, seed=None, response_mime_type=None, response_schema=None, routing_config=None, safety_settings=[SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HATE_SPEECH: 'HARM_CATEGORY_HATE_SPEECH'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_HARASSMENT: 'HARM_CATEGORY_HARASSMENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 'HARM_CATEGORY_DANGEROUS_CONTENT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>), SafetySetting(method=None, category=<HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 'HARM_CATEGORY_SEXUALLY_EXPLICIT'>, threshold=<HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: 'BLOCK_MEDIUM_AND_ABOVE'>)], tools=[Tool(function_declarations=[FunctionDeclaration(response=None, description='Test function', name='test_tool', parameters=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.OBJECT: 'OBJECT'>, description=None, enum=None, format=None, items=None, properties={'param1': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.ARRAY: 'ARRAY'>, description='Test parameters', enum=None, format=None, items=Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.STRING: 'STRING'>, description=None, enum=None, format=None, items=None, properties=None, required=None), properties=None, required=None), 'param2': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=None, description=None, enum=None, format=None, items=None, properties=None, required=None), 'param3': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.OBJECT: 'OBJECT'>, description=None, enum=None, format=None, items=None, properties={'json': Schema(min_items=None, example=None, property_ordering=None, pattern=None, minimum=None, default=None, any_of=None, max_length=None, title=None, min_length=None, min_properties=None, max_items=None, maximum=None, nullable=None, max_properties=None, type=<Type.STRING: 'STRING'>, description=None, enum=None, format=None, items=None, properties=None, required=None)}, required=[])}, required=[]))], retrieval=None, google_search=None, google_search_retrieval=None, code_execution=None), Tool(function_declarations=None, retrieval=None, google_search=GoogleSearch(), google_search_retrieval=None, code_execution=None)], tool_config=None, labels=None, cached_content=None, response_modalities=None, media_resolution=None, speech_config=None, audio_timestamp=None, automatic_function_calling=AutomaticFunctionCallingConfig(disable=True, maximum_remote_calls=None, ignore_call_history=None), thinking_config=None),
|
||||||
|
'history': list([
|
||||||
|
]),
|
||||||
|
'model': 'models/gemini-2.0-flash',
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().send_message',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'message': 'Please call the test function',
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().send_message',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'message': Content(parts=[Part(video_metadata=None, thought=None, code_execution_result=None, executable_code=None, file_data=None, function_call=None, function_response=FunctionResponse(id=None, name='test_tool', response={'result': 'Test response'}), inline_data=None, text=None)], role=None),
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
@ -21,12 +21,14 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
|||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
CONF_TOP_K,
|
CONF_TOP_K,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
|
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||||
RECOMMENDED_MAX_TOKENS,
|
RECOMMENDED_MAX_TOKENS,
|
||||||
RECOMMENDED_TOP_K,
|
RECOMMENDED_TOP_K,
|
||||||
RECOMMENDED_TOP_P,
|
RECOMMENDED_TOP_P,
|
||||||
|
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||||
)
|
)
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -143,6 +145,7 @@ async def test_form(hass: HomeAssistant) -> None:
|
|||||||
CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||||
CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||||
CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||||
|
CONF_USE_GOOGLE_SEARCH_TOOL: RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
@ -176,6 +176,72 @@ async def test_function_call(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||||
|
)
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
|
@pytest.mark.usefixtures("mock_ulid_tools")
|
||||||
|
async def test_use_google_search(
|
||||||
|
mock_get_tools,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry_with_google_search: MockConfigEntry,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test function calling."""
|
||||||
|
agent_id = "conversation.google_generative_ai_conversation"
|
||||||
|
context = Context()
|
||||||
|
|
||||||
|
mock_tool = AsyncMock()
|
||||||
|
mock_tool.name = "test_tool"
|
||||||
|
mock_tool.description = "Test function"
|
||||||
|
mock_tool.parameters = vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Optional("param1", description="Test parameters"): [
|
||||||
|
vol.All(str, vol.Lower)
|
||||||
|
],
|
||||||
|
vol.Optional("param2"): vol.Any(float, int),
|
||||||
|
vol.Optional("param3"): dict,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_get_tools.return_value = [mock_tool]
|
||||||
|
|
||||||
|
with patch("google.genai.chats.AsyncChats.create") as mock_create:
|
||||||
|
mock_chat = AsyncMock()
|
||||||
|
mock_create.return_value.send_message = mock_chat
|
||||||
|
chat_response = Mock(prompt_feedback=None)
|
||||||
|
mock_chat.return_value = chat_response
|
||||||
|
mock_part = Mock()
|
||||||
|
mock_part.text = ""
|
||||||
|
mock_part.function_call = FunctionCall(
|
||||||
|
name="test_tool",
|
||||||
|
args={
|
||||||
|
"param1": ["test_value", "param1\\'s value"],
|
||||||
|
"param2": 2.7,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def tool_call(
|
||||||
|
hass: HomeAssistant, tool_input: llm.ToolInput, tool_context: llm.LLMContext
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
mock_part.function_call = None
|
||||||
|
mock_part.text = "Hi there!"
|
||||||
|
return {"result": "Test response"}
|
||||||
|
|
||||||
|
mock_tool.async_call.side_effect = tool_call
|
||||||
|
chat_response.candidates = [Mock(content=Mock(parts=[mock_part]))]
|
||||||
|
await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"Please call the test function",
|
||||||
|
None,
|
||||||
|
context,
|
||||||
|
agent_id=agent_id,
|
||||||
|
device_id="test_device",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert [tuple(mock_call) for mock_call in mock_create.mock_calls] == snapshot
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI._async_get_tools"
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user