From 802ad2ff5120ce55099c8cf81ff4174cdd7fabf3 Mon Sep 17 00:00:00 2001 From: Ivan Lopez Hernandez Date: Thu, 3 Apr 2025 13:23:59 -0700 Subject: [PATCH] Made Google Search enable dependent on Assist availability (#141712) * Made Google Search enable dependent on Assist availability * Show error instead of rendering again * Cleanup test code --- .../config_flow.py | 18 +- .../strings.json | 3 + .../test_config_flow.py | 164 +++++++++++++++--- 3 files changed, 151 insertions(+), 34 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/config_flow.py b/homeassistant/components/google_generative_ai_conversation/config_flow.py index b7753c21bf9..ac6cb696a7d 100644 --- a/homeassistant/components/google_generative_ai_conversation/config_flow.py +++ b/homeassistant/components/google_generative_ai_conversation/config_flow.py @@ -179,28 +179,30 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow): ) -> ConfigFlowResult: """Manage the options.""" options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options + errors: dict[str, str] = {} if user_input is not None: if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended: if user_input[CONF_LLM_HASS_API] == "none": user_input.pop(CONF_LLM_HASS_API) - return self.async_create_entry(title="", data=user_input) + if not ( + user_input.get(CONF_LLM_HASS_API, "none") != "none" + and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True + ): + # Don't allow to save options that enable the Google Seearch tool with an Assist API + return self.async_create_entry(title="", data=user_input) + errors[CONF_USE_GOOGLE_SEARCH_TOOL] = "invalid_google_search_option" # Re-render the options again, now with the recommended options shown/hidden self.last_rendered_recommended = user_input[CONF_RECOMMENDED] - options = { - CONF_RECOMMENDED: user_input[CONF_RECOMMENDED], - CONF_PROMPT: user_input[CONF_PROMPT], - CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API], - } + options = user_input schema = await google_generative_ai_config_option_schema( self.hass, options, self._genai_client ) return self.async_show_form( - step_id="init", - data_schema=vol.Schema(schema), + step_id="init", data_schema=vol.Schema(schema), errors=errors ) diff --git a/homeassistant/components/google_generative_ai_conversation/strings.json b/homeassistant/components/google_generative_ai_conversation/strings.json index b814f89469a..cd7af711d3a 100644 --- a/homeassistant/components/google_generative_ai_conversation/strings.json +++ b/homeassistant/components/google_generative_ai_conversation/strings.json @@ -43,6 +43,9 @@ "prompt": "Instruct how the LLM should respond. This can be a template." } } + }, + "error": { + "invalid_google_search_option": "Google Search cannot be enabled alongside any Assist capability, this can only be used when Assist is set to \"No control\"." } }, "services": { diff --git a/tests/components/google_generative_ai_conversation/test_config_flow.py b/tests/components/google_generative_ai_conversation/test_config_flow.py index f7635c0b45e..8fda02b335d 100644 --- a/tests/components/google_generative_ai_conversation/test_config_flow.py +++ b/tests/components/google_generative_ai_conversation/test_config_flow.py @@ -39,9 +39,8 @@ from . import CLIENT_ERROR_500, CLIENT_ERROR_API_KEY_INVALID from tests.common import MockConfigEntry -@pytest.fixture -def mock_models(): - """Mock the model list API.""" +def get_models_pager(): + """Return a generator that yields the models.""" model_20_flash = Mock( display_name="Gemini 2.0 Flash", supported_actions=["generateContent"], @@ -72,11 +71,7 @@ def mock_models(): yield model_15_pro yield model_10_pro - with patch( - "google.genai.models.AsyncModels.list", - return_value=models_pager(), - ): - yield + return models_pager() async def test_form(hass: HomeAssistant) -> None: @@ -119,8 +114,13 @@ async def test_form(hass: HomeAssistant) -> None: assert len(mock_setup_entry.mock_calls) == 1 +def will_options_be_rendered_again(current_options, new_options) -> bool: + """Determine if options will be rendered again.""" + return current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED) + + @pytest.mark.parametrize( - ("current_options", "new_options", "expected_options"), + ("current_options", "new_options", "expected_options", "errors"), [ ( { @@ -147,6 +147,7 @@ async def test_form(hass: HomeAssistant) -> None: CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, CONF_USE_GOOGLE_SEARCH_TOOL: RECOMMENDED_USE_GOOGLE_SEARCH_TOOL, }, + None, ), ( { @@ -157,6 +158,7 @@ async def test_form(hass: HomeAssistant) -> None: CONF_TOP_P: RECOMMENDED_TOP_P, CONF_TOP_K: RECOMMENDED_TOP_K, CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, + CONF_USE_GOOGLE_SEARCH_TOOL: True, }, { CONF_RECOMMENDED: True, @@ -168,6 +170,98 @@ async def test_form(hass: HomeAssistant) -> None: CONF_LLM_HASS_API: "assist", CONF_PROMPT: "", }, + None, + ), + ( + { + CONF_RECOMMENDED: False, + CONF_PROMPT: "Speak like a pirate", + CONF_TEMPERATURE: 0.3, + CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, + CONF_TOP_P: RECOMMENDED_TOP_P, + CONF_TOP_K: RECOMMENDED_TOP_K, + CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, + CONF_HARASSMENT_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_USE_GOOGLE_SEARCH_TOOL: RECOMMENDED_USE_GOOGLE_SEARCH_TOOL, + }, + { + CONF_RECOMMENDED: False, + CONF_PROMPT: "Speak like a pirate", + CONF_TEMPERATURE: 0.3, + CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, + CONF_TOP_P: RECOMMENDED_TOP_P, + CONF_TOP_K: RECOMMENDED_TOP_K, + CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, + CONF_HARASSMENT_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_USE_GOOGLE_SEARCH_TOOL: True, + }, + { + CONF_RECOMMENDED: False, + CONF_PROMPT: "Speak like a pirate", + CONF_TEMPERATURE: 0.3, + CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, + CONF_TOP_P: RECOMMENDED_TOP_P, + CONF_TOP_K: RECOMMENDED_TOP_K, + CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, + CONF_HARASSMENT_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_USE_GOOGLE_SEARCH_TOOL: True, + }, + None, + ), + ( + { + CONF_RECOMMENDED: False, + CONF_PROMPT: "Speak like a pirate", + CONF_TEMPERATURE: 0.3, + CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, + CONF_TOP_P: RECOMMENDED_TOP_P, + CONF_TOP_K: RECOMMENDED_TOP_K, + CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, + CONF_HARASSMENT_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_USE_GOOGLE_SEARCH_TOOL: True, + }, + { + CONF_RECOMMENDED: False, + CONF_PROMPT: "Speak like a pirate", + CONF_LLM_HASS_API: "assist", + CONF_TEMPERATURE: 0.3, + CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, + CONF_TOP_P: RECOMMENDED_TOP_P, + CONF_TOP_K: RECOMMENDED_TOP_K, + CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, + CONF_HARASSMENT_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_USE_GOOGLE_SEARCH_TOOL: True, + }, + { + CONF_RECOMMENDED: False, + CONF_PROMPT: "Speak like a pirate", + CONF_TEMPERATURE: 0.3, + CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, + CONF_TOP_P: RECOMMENDED_TOP_P, + CONF_TOP_K: RECOMMENDED_TOP_K, + CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, + CONF_HARASSMENT_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_HATE_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_SEXUAL_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_DANGEROUS_BLOCK_THRESHOLD: RECOMMENDED_HARM_BLOCK_THRESHOLD, + CONF_USE_GOOGLE_SEARCH_TOOL: True, + }, + {CONF_USE_GOOGLE_SEARCH_TOOL: "invalid_google_search_option"}, ), ], ) @@ -175,10 +269,10 @@ async def test_form(hass: HomeAssistant) -> None: async def test_options_switching( hass: HomeAssistant, mock_config_entry: MockConfigEntry, - mock_models, current_options, new_options, expected_options, + errors, ) -> None: """Test the options form.""" with patch("google.genai.models.AsyncModels.get"): @@ -186,24 +280,42 @@ async def test_options_switching( mock_config_entry, options=current_options ) await hass.async_block_till_done() - options_flow = await hass.config_entries.options.async_init( - mock_config_entry.entry_id - ) - if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED): - options_flow = await hass.config_entries.options.async_configure( - options_flow["flow_id"], - { - **current_options, - CONF_RECOMMENDED: new_options[CONF_RECOMMENDED], - }, + with patch( + "google.genai.models.AsyncModels.list", + return_value=get_models_pager(), + ): + options_flow = await hass.config_entries.options.async_init( + mock_config_entry.entry_id + ) + if will_options_be_rendered_again(current_options, new_options): + retry_options = { + **current_options, + CONF_RECOMMENDED: new_options[CONF_RECOMMENDED], + } + with patch( + "google.genai.models.AsyncModels.list", + return_value=get_models_pager(), + ): + options_flow = await hass.config_entries.options.async_configure( + options_flow["flow_id"], + retry_options, + ) + with patch( + "google.genai.models.AsyncModels.list", + return_value=get_models_pager(), + ): + options = await hass.config_entries.options.async_configure( + options_flow["flow_id"], + new_options, ) - options = await hass.config_entries.options.async_configure( - options_flow["flow_id"], - new_options, - ) await hass.async_block_till_done() - assert options["type"] is FlowResultType.CREATE_ENTRY - assert options["data"] == expected_options + if errors is None: + assert options["type"] is FlowResultType.CREATE_ENTRY + assert options["data"] == expected_options + + else: + assert options["type"] is FlowResultType.FORM + assert options.get("errors", None) == errors @pytest.mark.parametrize(