Made Google Search enablement dependent on Assist availability (#141712)

* Made Google Search enablement dependent on Assist availability

* Show error instead of rendering again

* Cleanup test code
Ivan Lopez Hernandez 2025-04-03 13:23:59 -07:00 committed by Franck Nijhof
parent 9070a8d579
commit 802ad2ff51
3 changed files with 151 additions and 34 deletions
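The config_flow.py change below reduces to one validation rule: options that enable the Google Search tool can only be saved when no Assist API is selected. A minimal standalone sketch of that rule follows (the constant names match the diff, but their string values and the helper function are illustrative, not the integration's actual code):

# Hedged sketch of the validation rule this PR adds; key values are assumptions.
CONF_LLM_HASS_API = "llm_hass_api"                          # illustrative value
CONF_USE_GOOGLE_SEARCH_TOOL = "enable_google_search_tool"   # illustrative value


def google_search_option_allowed(user_input: dict) -> bool:
    """Return False when Google Search is enabled together with an Assist API."""
    assist_selected = user_input.get(CONF_LLM_HASS_API, "none") != "none"
    search_enabled = user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
    return not (assist_selected and search_enabled)


# Enabling both would trigger the new "invalid_google_search_option" error.
assert not google_search_option_allowed(
    {CONF_LLM_HASS_API: "assist", CONF_USE_GOOGLE_SEARCH_TOOL: True}
)
assert google_search_option_allowed({CONF_USE_GOOGLE_SEARCH_TOOL: True})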


@@ -179,28 +179,30 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
) -> ConfigFlowResult:
"""Manage the options."""
options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
errors: dict[str, str] = {}
if user_input is not None:
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(title="", data=user_input)
if not (
user_input.get(CONF_LLM_HASS_API, "none") != "none"
and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
):
# Don't allow saving options that enable the Google Search tool together with an Assist API
return self.async_create_entry(title="", data=user_input)
errors[CONF_USE_GOOGLE_SEARCH_TOOL] = "invalid_google_search_option"
# Re-render the options form, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
options = {
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
CONF_PROMPT: user_input[CONF_PROMPT],
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
}
options = user_input
schema = await google_generative_ai_config_option_schema(
self.hass, options, self._genai_client
)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema),
step_id="init", data_schema=vol.Schema(schema), errors=errors
)


@@ -43,6 +43,9 @@
"prompt": "Instruct how the LLM should respond. This can be a template."
}
}
},
"error": {
"invalid_google_search_option": "Google Search cannot be enabled alongside any Assist capability, this can only be used when Assist is set to \"No control\"."
}
},
"services": {


@@ -39,9 +39,8 @@ from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
from tests.common import MockConfigEntry
@pytest.fixture
def mock_models():
"""Mock the model list API."""
def get_models_pager():
"""Return a generator that yields the models."""
model_20_flash = Mock(
display_name="Gemini 2.0 Flash",
supported_actions=["generateContent"],
@@ -72,11 +71,7 @@ def mock_models():
yield model_15_pro
yield model_10_pro
with patch(
"google.genai.models.AsyncModels.list",
return_value=models_pager(),
):
yield
return models_pager()
async def test_form(hass: HomeAssistant) -> None:
@@ -119,8 +114,13 @@ async def test_form(hass: HomeAssistant) -> None:
assert len(mock_setup_entry.mock_calls) == 1
def will_options_be_rendered_again(current_options, new_options) -> bool:
"""Determine if options will be rendered again."""
return current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED)
@pytest.mark.parametrize(
("current_options", "new_options", "expected_options"),
("current_options", "new_options", "expected_options", "errors"),
[
(
{
@@ -147,6 +147,7 @@ async def test_form(hass: HomeAssistant) -> None:
CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_USE_GOOGLE_SEARCH_TOOL: RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
},
None,
),
(
{
@@ -157,6 +158,7 @@ async def test_form(hass: HomeAssistant) -> None:
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_TOP_K: RECOMMENDED_TOP_K,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_USE_GOOGLE_SEARCH_TOOL: True,
},
{
CONF_RECOMMENDED: True,
@@ -168,6 +170,98 @@ async def test_form(hass: HomeAssistant) -> None:
CONF_LLM_HASS_API: "assist",
CONF_PROMPT: "",
},
None,
),
(
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_TOP_K: RECOMMENDED_TOP_K,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_HARASSMENT_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_USE_GOOGLE_SEARCH_TOOL: RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
},
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_TOP_K: RECOMMENDED_TOP_K,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_HARASSMENT_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_USE_GOOGLE_SEARCH_TOOL: True,
},
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_TOP_K: RECOMMENDED_TOP_K,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_HARASSMENT_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_USE_GOOGLE_SEARCH_TOOL: True,
},
None,
),
(
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_TOP_K: RECOMMENDED_TOP_K,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_HARASSMENT_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_USE_GOOGLE_SEARCH_TOOL: True,
},
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_LLM_HASS_API: "assist",
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_TOP_K: RECOMMENDED_TOP_K,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_HARASSMENT_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_USE_GOOGLE_SEARCH_TOOL: True,
},
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_TOP_K: RECOMMENDED_TOP_K,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
CONF_HARASSMENT_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD,
CONF_USE_GOOGLE_SEARCH_TOOL: True,
},
{CONF_USE_GOOGLE_SEARCH_TOOL: "invalid_google_search_option"},
),
],
)
@@ -175,10 +269,10 @@ async def test_form(hass: HomeAssistant) -> None:
async def test_options_switching(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_models,
current_options,
new_options,
expected_options,
errors,
) -> None:
"""Test the options form."""
with patch("google.genai.models.AsyncModels.get"):
@@ -186,24 +280,42 @@ async def test_options_switching(
mock_config_entry, options=current_options
)
await hass.async_block_till_done()
options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id
)
if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED):
options_flow = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
{
**current_options,
CONF_RECOMMENDED: new_options[CONF_RECOMMENDED],
},
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id
)
if will_options_be_rendered_again(current_options, new_options):
retry_options = {
**current_options,
CONF_RECOMMENDED: new_options[CONF_RECOMMENDED],
}
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
options_flow = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
retry_options,
)
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
options = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
new_options,
)
options = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
new_options,
)
await hass.async_block_till_done()
assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"] == expected_options
if errors is None:
assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"] == expected_options
else:
assert options["type"] is FlowResultType.FORM
assert options.get("errors", None) == errors
@pytest.mark.parametrize(